diff options
| author | joshualitt <[email protected]> | 2025-08-06 10:50:02 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-08-06 17:50:02 +0000 |
| commit | 6133bea388a2de69c71a6be6f1450707f2ce4dfb (patch) | |
| tree | 367de1d618069ea80e47d7e86c4fb8f82ad032a7 /packages/core/src | |
| parent | 882a97aff998b2f19731e9966d135f1db5a59914 (diff) | |
feat(core): Introduce `DeclarativeTool` and `ToolInvocation`. (#5613)
Diffstat (limited to 'packages/core/src')
19 files changed, 708 insertions, 457 deletions
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 3b6b57f9..f8b9a7de 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -24,7 +24,6 @@ import { import { Config } from '../config/config.js'; import { UserTierId } from '../code_assist/types.js'; import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js'; -import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; import { reportError } from '../utils/errorReporting.js'; @@ -252,18 +251,15 @@ export class GeminiClient { // Add full file context if the flag is set if (this.config.getFullContext()) { try { - const readManyFilesTool = toolRegistry.getTool( - 'read_many_files', - ) as ReadManyFilesTool; + const readManyFilesTool = toolRegistry.getTool('read_many_files'); if (readManyFilesTool) { + const invocation = readManyFilesTool.build({ + paths: ['**/*'], // Read everything recursively + useDefaultExcludes: true, // Use default excludes + }); + // Read all files in the target directory - const result = await readManyFilesTool.execute( - { - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }, - AbortSignal.timeout(30000), - ); + const result = await invocation.execute(AbortSignal.timeout(30000)); if (result.llmContent) { initialParts.push({ text: `\n--- Full File Context ---\n${result.llmContent}`, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 4d786d00..a65443f8 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -24,44 +24,15 @@ import { } from '../index.js'; import { Part, PartListUnion } from '@google/genai'; -import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js'; - -class MockTool extends BaseTool<Record<string, unknown>, ToolResult> { - shouldConfirm = false; - executeFn = vi.fn(); - - constructor(name = 'mockTool') { - super(name, name, 'A mock tool', Icon.Hammer, {}); - } - - async shouldConfirmExecute( - _params: Record<string, unknown>, - _abortSignal: AbortSignal, - ): Promise<ToolCallConfirmationDetails | false> { - if (this.shouldConfirm) { - return { - type: 'exec', - title: 'Confirm Mock Tool', - command: 'do_thing', - rootCommand: 'do_thing', - onConfirm: async () => {}, - }; - } - return false; - } - - async execute( - params: Record<string, unknown>, - _abortSignal: AbortSignal, - ): Promise<ToolResult> { - this.executeFn(params); - return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' }; - } -} +import { + ModifiableDeclarativeTool, + ModifyContext, +} from '../tools/modifiable-tool.js'; +import { MockTool } from '../test-utils/tools.js'; class MockModifiableTool extends MockTool - implements ModifiableTool<Record<string, unknown>> + implements ModifiableDeclarativeTool<Record<string, unknown>> { constructor(name = 'mockModifiableTool') { super(name); @@ -83,10 +54,7 @@ class MockModifiableTool }; } - async shouldConfirmExecute( - _params: Record<string, unknown>, - _abortSignal: AbortSignal, - ): Promise<ToolCallConfirmationDetails | false> { + async shouldConfirmExecute(): Promise<ToolCallConfirmationDetails | false> { if (this.shouldConfirm) { return { type: 'edit', @@ -107,14 +75,15 @@ describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { const mockTool = new MockTool(); mockTool.shouldConfirm = true; + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockTool, - getToolByDisplayName: () => mockTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -177,14 +146,15 @@ describe('CoreToolScheduler', () => { describe('CoreToolScheduler with payload', () => { it('should update args and diff and execute tool when payload is provided', async () => { const mockTool = new MockModifiableTool(); + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockTool, - getToolByDisplayName: () => mockTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -221,10 +191,7 @@ describe('CoreToolScheduler with payload', () => { await scheduler.schedule([request], abortController.signal); - const confirmationDetails = await mockTool.shouldConfirmExecute( - {}, - abortController.signal, - ); + const confirmationDetails = await mockTool.shouldConfirmExecute(); if (confirmationDetails) { const payload: ToolConfirmationPayload = { newContent: 'final version' }; @@ -456,14 +423,15 @@ describe('CoreToolScheduler edit cancellation', () => { } const mockEditTool = new MockEditTool(); + const declarativeTool = mockEditTool; const toolRegistry = { - getTool: () => mockEditTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockEditTool, - getToolByDisplayName: () => mockEditTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -541,18 +509,23 @@ describe('CoreToolScheduler YOLO mode', () => { it('should execute tool requiring confirmation directly without waiting', async () => { // Arrange const mockTool = new MockTool(); + mockTool.executeFn.mockReturnValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); // This tool would normally require confirmation. mockTool.shouldConfirm = true; + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, - getToolByName: () => mockTool, + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, // Other properties are not needed for this test but are included for type consistency. getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByDisplayName: () => mockTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 9b999b6b..6f098ae3 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -8,7 +8,6 @@ import { ToolCallRequestInfo, ToolCallResponseInfo, ToolConfirmationOutcome, - Tool, ToolCallConfirmationDetails, ToolResult, ToolResultDisplay, @@ -20,11 +19,13 @@ import { ToolCallEvent, ToolConfirmationPayload, ToolErrorType, + AnyDeclarativeTool, + AnyToolInvocation, } from '../index.js'; import { Part, PartListUnion } from '@google/genai'; import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js'; import { - isModifiableTool, + isModifiableDeclarativeTool, ModifyContext, modifyWithEditor, } from '../tools/modifiable-tool.js'; @@ -33,7 +34,8 @@ import * as Diff from 'diff'; export type ValidatingToolCall = { status: 'validating'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; startTime?: number; outcome?: ToolConfirmationOutcome; }; @@ -41,7 +43,8 @@ export type ValidatingToolCall = { export type ScheduledToolCall = { status: 'scheduled'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; startTime?: number; outcome?: ToolConfirmationOutcome; }; @@ -50,6 +53,7 @@ export type ErroredToolCall = { status: 'error'; request: ToolCallRequestInfo; response: ToolCallResponseInfo; + tool?: AnyDeclarativeTool; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -57,8 +61,9 @@ export type ErroredToolCall = { export type SuccessfulToolCall = { status: 'success'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; response: ToolCallResponseInfo; + invocation: AnyToolInvocation; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -66,7 +71,8 @@ export type SuccessfulToolCall = { export type ExecutingToolCall = { status: 'executing'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; liveOutput?: string; startTime?: number; outcome?: ToolConfirmationOutcome; @@ -76,7 +82,8 @@ export type CancelledToolCall = { status: 'cancelled'; request: ToolCallRequestInfo; response: ToolCallResponseInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -84,7 +91,8 @@ export type CancelledToolCall = { export type WaitingToolCall = { status: 'awaiting_approval'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; confirmationDetails: ToolCallConfirmationDetails; startTime?: number; outcome?: ToolConfirmationOutcome; @@ -289,6 +297,7 @@ export class CoreToolScheduler { // currentCall is a non-terminal state here and should have startTime and tool. const existingStartTime = currentCall.startTime; const toolInstance = currentCall.tool; + const invocation = currentCall.invocation; const outcome = currentCall.outcome; @@ -300,6 +309,7 @@ export class CoreToolScheduler { return { request: currentCall.request, tool: toolInstance, + invocation, status: 'success', response: auxiliaryData as ToolCallResponseInfo, durationMs, @@ -313,6 +323,7 @@ export class CoreToolScheduler { return { request: currentCall.request, status: 'error', + tool: toolInstance, response: auxiliaryData as ToolCallResponseInfo, durationMs, outcome, @@ -326,6 +337,7 @@ export class CoreToolScheduler { confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, startTime: existingStartTime, outcome, + invocation, } as WaitingToolCall; case 'scheduled': return { @@ -334,6 +346,7 @@ export class CoreToolScheduler { status: 'scheduled', startTime: existingStartTime, outcome, + invocation, } as ScheduledToolCall; case 'cancelled': { const durationMs = existingStartTime @@ -358,6 +371,7 @@ export class CoreToolScheduler { return { request: currentCall.request, tool: toolInstance, + invocation, status: 'cancelled', response: { callId: currentCall.request.callId, @@ -385,6 +399,7 @@ export class CoreToolScheduler { status: 'validating', startTime: existingStartTime, outcome, + invocation, } as ValidatingToolCall; case 'executing': return { @@ -393,6 +408,7 @@ export class CoreToolScheduler { status: 'executing', startTime: existingStartTime, outcome, + invocation, } as ExecutingToolCall; default: { const exhaustiveCheck: never = newStatus; @@ -406,10 +422,34 @@ export class CoreToolScheduler { private setArgsInternal(targetCallId: string, args: unknown): void { this.toolCalls = this.toolCalls.map((call) => { - if (call.request.callId !== targetCallId) return call; + // We should never be asked to set args on an ErroredToolCall, but + // we guard for the case anyways. + if (call.request.callId !== targetCallId || call.status === 'error') { + return call; + } + + const invocationOrError = this.buildInvocation( + call.tool, + args as Record<string, unknown>, + ); + if (invocationOrError instanceof Error) { + const response = createErrorResponse( + call.request, + invocationOrError, + ToolErrorType.INVALID_TOOL_PARAMS, + ); + return { + request: { ...call.request, args: args as Record<string, unknown> }, + status: 'error', + tool: call.tool, + response, + } as ErroredToolCall; + } + return { ...call, request: { ...call.request, args: args as Record<string, unknown> }, + invocation: invocationOrError, }; }); } @@ -421,6 +461,20 @@ export class CoreToolScheduler { ); } + private buildInvocation( + tool: AnyDeclarativeTool, + args: object, + ): AnyToolInvocation | Error { + try { + return tool.build(args); + } catch (e) { + if (e instanceof Error) { + return e; + } + return new Error(String(e)); + } + } + async schedule( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, @@ -448,10 +502,30 @@ export class CoreToolScheduler { durationMs: 0, }; } + + const invocationOrError = this.buildInvocation( + toolInstance, + reqInfo.args, + ); + if (invocationOrError instanceof Error) { + return { + status: 'error', + request: reqInfo, + tool: toolInstance, + response: createErrorResponse( + reqInfo, + invocationOrError, + ToolErrorType.INVALID_TOOL_PARAMS, + ), + durationMs: 0, + }; + } + return { status: 'validating', request: reqInfo, tool: toolInstance, + invocation: invocationOrError, startTime: Date.now(), }; }, @@ -465,7 +539,8 @@ export class CoreToolScheduler { continue; } - const { request: reqInfo, tool: toolInstance } = toolCall; + const { request: reqInfo, invocation } = toolCall; + try { if (this.config.getApprovalMode() === ApprovalMode.YOLO) { this.setToolCallOutcome( @@ -474,10 +549,8 @@ export class CoreToolScheduler { ); this.setStatusInternal(reqInfo.callId, 'scheduled'); } else { - const confirmationDetails = await toolInstance.shouldConfirmExecute( - reqInfo.args, - signal, - ); + const confirmationDetails = + await invocation.shouldConfirmExecute(signal); if (confirmationDetails) { // Allow IDE to resolve confirmation @@ -573,7 +646,7 @@ export class CoreToolScheduler { ); } else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { const waitingToolCall = toolCall as WaitingToolCall; - if (isModifiableTool(waitingToolCall.tool)) { + if (isModifiableDeclarativeTool(waitingToolCall.tool)) { const modifyContext = waitingToolCall.tool.getModifyContext(signal); const editorType = this.getPreferredEditor(); if (!editorType) { @@ -628,7 +701,7 @@ export class CoreToolScheduler { ): Promise<void> { if ( toolCall.confirmationDetails.type !== 'edit' || - !isModifiableTool(toolCall.tool) + !isModifiableDeclarativeTool(toolCall.tool) ) { return; } @@ -677,6 +750,7 @@ export class CoreToolScheduler { const scheduledCall = toolCall; const { callId, name: toolName } = scheduledCall.request; + const invocation = scheduledCall.invocation; this.setStatusInternal(callId, 'executing'); const liveOutputCallback = @@ -694,8 +768,8 @@ export class CoreToolScheduler { } : undefined; - scheduledCall.tool - .execute(scheduledCall.request.args, signal, liveOutputCallback) + invocation + .execute(signal, liveOutputCallback) .then(async (toolResult: ToolResult) => { if (signal.aborted) { this.setStatusInternal( diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 1bbb9209..b0ed7107 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -10,12 +10,10 @@ import { ToolRegistry, ToolCallRequestInfo, ToolResult, - Tool, - ToolCallConfirmationDetails, Config, - Icon, } from '../index.js'; -import { Part, Type } from '@google/genai'; +import { Part } from '@google/genai'; +import { MockTool } from '../test-utils/tools.js'; const mockConfig = { getSessionId: () => 'test-session-id', @@ -25,36 +23,11 @@ const mockConfig = { describe('executeToolCall', () => { let mockToolRegistry: ToolRegistry; - let mockTool: Tool; + let mockTool: MockTool; let abortController: AbortController; beforeEach(() => { - mockTool = { - name: 'testTool', - displayName: 'Test Tool', - description: 'A tool for testing', - icon: Icon.Hammer, - schema: { - name: 'testTool', - description: 'A tool for testing', - parameters: { - type: Type.OBJECT, - properties: { - param1: { type: Type.STRING }, - }, - required: ['param1'], - }, - }, - execute: vi.fn(), - validateToolParams: vi.fn(() => null), - shouldConfirmExecute: vi.fn(() => - Promise.resolve(false as false | ToolCallConfirmationDetails), - ), - isOutputMarkdown: false, - canUpdateOutput: false, - getDescription: vi.fn(), - toolLocations: vi.fn(() => []), - }; + mockTool = new MockTool(); mockToolRegistry = { getTool: vi.fn(), @@ -77,7 +50,7 @@ describe('executeToolCall', () => { returnDisplay: 'Success!', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, @@ -87,7 +60,7 @@ describe('executeToolCall', () => { ); expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); - expect(mockTool.execute).toHaveBeenCalledWith( + expect(mockTool.buildAndExecute).toHaveBeenCalledWith( request.args, abortController.signal, ); @@ -149,7 +122,7 @@ describe('executeToolCall', () => { }; const executionError = new Error('Tool execution failed'); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockRejectedValue(executionError); + vi.spyOn(mockTool, 'buildAndExecute').mockRejectedValue(executionError); const response = await executeToolCall( mockConfig, @@ -183,25 +156,27 @@ describe('executeToolCall', () => { const cancellationError = new Error('Operation cancelled'); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => { - if (signal?.aborted) { - return Promise.reject(cancellationError); - } - return new Promise((_resolve, reject) => { - signal?.addEventListener('abort', () => { - reject(cancellationError); + vi.spyOn(mockTool, 'buildAndExecute').mockImplementation( + async (_args, signal) => { + if (signal?.aborted) { + return Promise.reject(cancellationError); + } + return new Promise((_resolve, reject) => { + signal?.addEventListener('abort', () => { + reject(cancellationError); + }); + // Simulate work that might happen if not aborted immediately + const timeoutId = setTimeout( + () => + reject( + new Error('Should have been cancelled if not aborted prior'), + ), + 100, + ); + signal?.addEventListener('abort', () => clearTimeout(timeoutId)); }); - // Simulate work that might happen if not aborted immediately - const timeoutId = setTimeout( - () => - reject( - new Error('Should have been cancelled if not aborted prior'), - ), - 100, - ); - signal?.addEventListener('abort', () => clearTimeout(timeoutId)); - }); - }); + }, + ); abortController.abort(); // Abort before calling const response = await executeToolCall( @@ -232,7 +207,7 @@ describe('executeToolCall', () => { returnDisplay: 'Image processed', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, diff --git a/packages/core/src/core/nonInteractiveToolExecutor.ts b/packages/core/src/core/nonInteractiveToolExecutor.ts index ed235cd3..43061f83 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.ts @@ -65,7 +65,7 @@ export async function executeToolCall( try { // Directly execute without confirmation or live output handling const effectiveAbortSignal = abortSignal ?? new AbortController().signal; - const toolResult: ToolResult = await tool.execute( + const toolResult: ToolResult = await tool.buildAndExecute( toolCallRequest.args, effectiveAbortSignal, // No live output callback for non-interactive mode diff --git a/packages/core/src/telemetry/loggers.test.circular.ts b/packages/core/src/telemetry/loggers.test.circular.ts index 80444a0d..3cf85e46 100644 --- a/packages/core/src/telemetry/loggers.test.circular.ts +++ b/packages/core/src/telemetry/loggers.test.circular.ts @@ -14,7 +14,7 @@ import { ToolCallEvent } from './types.js'; import { Config } from '../config/config.js'; import { CompletedToolCall } from '../core/coreToolScheduler.js'; import { ToolCallRequestInfo, ToolCallResponseInfo } from '../core/turn.js'; -import { Tool } from '../tools/tools.js'; +import { MockTool } from '../test-utils/tools.js'; describe('Circular Reference Handling', () => { it('should handle circular references in tool function arguments', () => { @@ -56,11 +56,13 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; + const tool = new MockTool('mock-tool'); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, response: mockResponse, - tool: {} as Tool, + tool, + invocation: tool.build({}), durationMs: 100, }; @@ -104,11 +106,13 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; + const tool = new MockTool('mock-tool'); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, response: mockResponse, - tool: {} as Tool, + tool, + invocation: tool.build({}), durationMs: 100, }; diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 3d8116cc..14de83a9 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -5,6 +5,7 @@ */ import { + AnyToolInvocation, AuthType, CompletedToolCall, ContentGeneratorConfig, @@ -432,6 +433,7 @@ describe('loggers', () => { }); it('should log a tool call with all fields', () => { + const tool = new EditTool(mockConfig); const call: CompletedToolCall = { status: 'success', request: { @@ -451,7 +453,8 @@ describe('loggers', () => { error: undefined, errorType: undefined, }, - tool: new EditTool(mockConfig), + tool, + invocation: {} as AnyToolInvocation, durationMs: 100, outcome: ToolConfirmationOutcome.ProceedOnce, }; @@ -581,6 +584,7 @@ describe('loggers', () => { }, outcome: ToolConfirmationOutcome.ModifyWithEditor, tool: new EditTool(mockConfig), + invocation: {} as AnyToolInvocation, durationMs: 100, }; const event = new ToolCallEvent(call); @@ -645,6 +649,7 @@ describe('loggers', () => { errorType: undefined, }, tool: new EditTool(mockConfig), + invocation: {} as AnyToolInvocation, durationMs: 100, }; const event = new ToolCallEvent(call); diff --git a/packages/core/src/telemetry/uiTelemetry.test.ts b/packages/core/src/telemetry/uiTelemetry.test.ts index 221804d2..ac9727f1 100644 --- a/packages/core/src/telemetry/uiTelemetry.test.ts +++ b/packages/core/src/telemetry/uiTelemetry.test.ts @@ -23,7 +23,8 @@ import { SuccessfulToolCall, } from '../core/coreToolScheduler.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { Tool, ToolConfirmationOutcome } from '../tools/tools.js'; +import { ToolConfirmationOutcome } from '../tools/tools.js'; +import { MockTool } from '../test-utils/tools.js'; const createFakeCompletedToolCall = ( name: string, @@ -39,12 +40,14 @@ const createFakeCompletedToolCall = ( isClientInitiated: false, prompt_id: 'prompt-id-1', }; + const tool = new MockTool(name); if (success) { return { status: 'success', request, - tool: { name } as Tool, // Mock tool + tool, + invocation: tool.build({}), response: { callId: request.callId, responseParts: { @@ -65,6 +68,7 @@ const createFakeCompletedToolCall = ( return { status: 'error', request, + tool, response: { callId: request.callId, responseParts: { diff --git a/packages/core/src/test-utils/tools.ts b/packages/core/src/test-utils/tools.ts new file mode 100644 index 00000000..b168db9c --- /dev/null +++ b/packages/core/src/test-utils/tools.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi } from 'vitest'; +import { + BaseTool, + Icon, + ToolCallConfirmationDetails, + ToolResult, +} from '../tools/tools.js'; +import { Schema, Type } from '@google/genai'; + +/** + * A highly configurable mock tool for testing purposes. + */ +export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> { + executeFn = vi.fn(); + shouldConfirm = false; + + constructor( + name = 'mock-tool', + displayName?: string, + description = 'A mock tool for testing.', + params: Schema = { + type: Type.OBJECT, + properties: { param: { type: Type.STRING } }, + }, + ) { + super(name, displayName ?? name, description, Icon.Hammer, params); + } + + async execute( + params: { [key: string]: unknown }, + _abortSignal: AbortSignal, + ): Promise<ToolResult> { + const result = this.executeFn(params); + return ( + result ?? { + llmContent: `Tool ${this.name} executed successfully.`, + returnDisplay: `Tool ${this.name} executed successfully.`, + } + ); + } + + async shouldConfirmExecute( + _params: { [key: string]: unknown }, + _abortSignal: AbortSignal, + ): Promise<ToolCallConfirmationDetails | false> { + if (this.shouldConfirm) { + return { + type: 'exec' as const, + title: `Confirm ${this.displayName}`, + command: this.name, + rootCommand: this.name, + onConfirm: async () => {}, + }; + } + return false; + } +} diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 0d129e42..853ad4c1 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -26,7 +26,7 @@ import { Config, ApprovalMode } from '../config/config.js'; import { ensureCorrectEdit } from '../utils/editCorrector.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { ReadFileTool } from './read-file.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; /** * Parameters for the Edit tool @@ -72,7 +72,7 @@ interface CalculatedEdit { */ export class EditTool extends BaseTool<EditToolParams, ToolResult> - implements ModifiableTool<EditToolParams> + implements ModifiableDeclarativeTool<EditToolParams> { static readonly Name = 'replace'; diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 847ea5cf..f3bf315b 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -18,7 +18,7 @@ import { homedir } from 'os'; import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { tildeifyPath } from '../utils/paths.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -112,7 +112,7 @@ function ensureNewlineSeparation(currentContent: string): string { export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> - implements ModifiableTool<SaveMemoryParams> + implements ModifiableDeclarativeTool<SaveMemoryParams> { private static readonly allowlist: Set<string> = new Set(); diff --git a/packages/core/src/tools/modifiable-tool.test.ts b/packages/core/src/tools/modifiable-tool.test.ts index eb7e8dbf..dc68640a 100644 --- a/packages/core/src/tools/modifiable-tool.test.ts +++ b/packages/core/src/tools/modifiable-tool.test.ts @@ -8,8 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { modifyWithEditor, ModifyContext, - ModifiableTool, - isModifiableTool, + ModifiableDeclarativeTool, + isModifiableDeclarativeTool, } from './modifiable-tool.js'; import { EditorType } from '../utils/editor.js'; import fs from 'fs'; @@ -338,16 +338,16 @@ describe('isModifiableTool', () => { const mockTool = { name: 'test-tool', getModifyContext: vi.fn(), - } as unknown as ModifiableTool<TestParams>; + } as unknown as ModifiableDeclarativeTool<TestParams>; - expect(isModifiableTool(mockTool)).toBe(true); + expect(isModifiableDeclarativeTool(mockTool)).toBe(true); }); it('should return false for objects without getModifyContext method', () => { const mockTool = { name: 'test-tool', - } as unknown as ModifiableTool<TestParams>; + } as unknown as ModifiableDeclarativeTool<TestParams>; - expect(isModifiableTool(mockTool)).toBe(false); + expect(isModifiableDeclarativeTool(mockTool)).toBe(false); }); }); diff --git a/packages/core/src/tools/modifiable-tool.ts b/packages/core/src/tools/modifiable-tool.ts index 42de3eb6..25a2906b 100644 --- a/packages/core/src/tools/modifiable-tool.ts +++ b/packages/core/src/tools/modifiable-tool.ts @@ -11,13 +11,14 @@ import fs from 'fs'; import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { isNodeError } from '../utils/errors.js'; -import { Tool } from './tools.js'; +import { AnyDeclarativeTool, DeclarativeTool, ToolResult } from './tools.js'; /** - * A tool that supports a modify operation. + * A declarative tool that supports a modify operation. */ -export interface ModifiableTool<ToolParams> extends Tool<ToolParams> { - getModifyContext(abortSignal: AbortSignal): ModifyContext<ToolParams>; +export interface ModifiableDeclarativeTool<TParams extends object> + extends DeclarativeTool<TParams, ToolResult> { + getModifyContext(abortSignal: AbortSignal): ModifyContext<TParams>; } export interface ModifyContext<ToolParams> { @@ -39,9 +40,12 @@ export interface ModifyResult<ToolParams> { updatedDiff: string; } -export function isModifiableTool<TParams>( - tool: Tool<TParams>, -): tool is ModifiableTool<TParams> { +/** + * Type guard to check if a declarative tool is modifiable. + */ +export function isModifiableDeclarativeTool( + tool: AnyDeclarativeTool, +): tool is ModifiableDeclarativeTool<object> { return 'getModifyContext' in tool; } diff --git a/packages/core/src/tools/read-file.test.ts b/packages/core/src/tools/read-file.test.ts index fa1e458c..bb9317fd 100644 --- a/packages/core/src/tools/read-file.test.ts +++ b/packages/core/src/tools/read-file.test.ts @@ -13,6 +13,7 @@ import fsp from 'fs/promises'; import { Config } from '../config/config.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; +import { ToolInvocation, ToolResult } from './tools.js'; describe('ReadFileTool', () => { let tempRootDir: string; @@ -40,57 +41,62 @@ describe('ReadFileTool', () => { } }); - describe('validateToolParams', () => { - it('should return null for valid params (absolute path within root)', () => { + describe('build', () => { + it('should return an invocation for valid params (absolute path within root)', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), }; - expect(tool.validateToolParams(params)).toBeNull(); + const result = tool.build(params); + expect(result).not.toBeTypeOf('string'); + expect(typeof result).toBe('object'); + expect( + (result as ToolInvocation<ReadFileToolParams, ToolResult>).params, + ).toEqual(params); }); - it('should return null for valid params with offset and limit', () => { + it('should return an invocation for valid params with offset and limit', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: 0, limit: 10, }; - expect(tool.validateToolParams(params)).toBeNull(); + const result = tool.build(params); + expect(result).not.toBeTypeOf('string'); }); - it('should return error for relative path', () => { + it('should throw error for relative path', () => { const params: ReadFileToolParams = { absolute_path: 'test.txt' }; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( `File path must be absolute, but was relative: test.txt. You must provide an absolute path.`, ); }); - it('should return error for path outside root', () => { + it('should throw error for path outside root', () => { const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt'); const params: ReadFileToolParams = { absolute_path: outsidePath }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); }); - it('should return error for negative offset', () => { + it('should throw error for negative offset', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: -1, limit: 10, }; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( 'Offset must be a non-negative number', ); }); - it('should return error for non-positive limit', () => { + it('should throw error for non-positive limit', () => { const paramsZero: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: 0, limit: 0, }; - expect(tool.validateToolParams(paramsZero)).toBe( + expect(() => tool.build(paramsZero)).toThrow( 'Limit must be a positive number', ); const paramsNegative: ReadFileToolParams = { @@ -98,168 +104,182 @@ describe('ReadFileTool', () => { offset: 0, limit: -5, }; - expect(tool.validateToolParams(paramsNegative)).toBe( + expect(() => tool.build(paramsNegative)).toThrow( 'Limit must be a positive number', ); }); - it('should return error for schema validation failure (e.g. missing path)', () => { + it('should throw error for schema validation failure (e.g. missing path)', () => { const params = { offset: 0 } as unknown as ReadFileToolParams; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( `params must have required property 'absolute_path'`, ); }); }); - describe('getDescription', () => { - it('should return a shortened, relative path', () => { - const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt'); - const params: ReadFileToolParams = { absolute_path: filePath }; - expect(tool.getDescription(params)).toBe( - path.join('sub', 'dir', 'file.txt'), - ); - }); - - it('should return . if path is the root directory', () => { - const params: ReadFileToolParams = { absolute_path: tempRootDir }; - expect(tool.getDescription(params)).toBe('.'); - }); - }); + describe('ToolInvocation', () => { + describe('getDescription', () => { + it('should return a shortened, relative path', () => { + const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params); + expect(typeof invocation).not.toBe('string'); + expect( + ( + invocation as ToolInvocation<ReadFileToolParams, ToolResult> + ).getDescription(), + ).toBe(path.join('sub', 'dir', 'file.txt')); + }); - describe('execute', () => { - it('should return validation error if params are invalid', async () => { - const params: ReadFileToolParams = { - absolute_path: 'relative/path.txt', - }; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: - 'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.', - returnDisplay: - 'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.', + it('should return . if path is the root directory', () => { + const params: ReadFileToolParams = { absolute_path: tempRootDir }; + const invocation = tool.build(params); + expect(typeof invocation).not.toBe('string'); + expect( + ( + invocation as ToolInvocation<ReadFileToolParams, ToolResult> + ).getDescription(), + ).toBe('.'); }); }); - it('should return error if file does not exist', async () => { - const filePath = path.join(tempRootDir, 'nonexistent.txt'); - const params: ReadFileToolParams = { absolute_path: filePath }; + describe('execute', () => { + it('should return error if file does not exist', async () => { + const filePath = path.join(tempRootDir, 'nonexistent.txt'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `File not found: ${filePath}`, - returnDisplay: 'File not found.', + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: `File not found: ${filePath}`, + returnDisplay: 'File not found.', + }); }); - }); - it('should return success result for a text file', async () => { - const filePath = path.join(tempRootDir, 'textfile.txt'); - const fileContent = 'This is a test file.'; - await fsp.writeFile(filePath, fileContent, 'utf-8'); - const params: ReadFileToolParams = { absolute_path: filePath }; + it('should return success result for a text file', async () => { + const filePath = path.join(tempRootDir, 'textfile.txt'); + const fileContent = 'This is a test file.'; + await fsp.writeFile(filePath, fileContent, 'utf-8'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: fileContent, - returnDisplay: '', + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: fileContent, + returnDisplay: '', + }); }); - }); - it('should return success result for an image file', async () => { - // A minimal 1x1 transparent PNG file. - const pngContent = Buffer.from([ - 137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0, - 1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65, - 84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73, - 69, 78, 68, 174, 66, 96, 130, - ]); - const filePath = path.join(tempRootDir, 'image.png'); - await fsp.writeFile(filePath, pngContent); - const params: ReadFileToolParams = { absolute_path: filePath }; + it('should return success result for an image file', async () => { + // A minimal 1x1 transparent PNG file. + const pngContent = Buffer.from([ + 137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0, + 1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, + 65, 84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, + 0, 73, 69, 78, 68, 174, 66, 96, 130, + ]); + const filePath = path.join(tempRootDir, 'image.png'); + await fsp.writeFile(filePath, pngContent); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: { - inlineData: { - mimeType: 'image/png', - data: pngContent.toString('base64'), + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: { + inlineData: { + mimeType: 'image/png', + data: pngContent.toString('base64'), + }, }, - }, - returnDisplay: `Read image file: image.png`, + returnDisplay: `Read image file: image.png`, + }); }); - }); - it('should treat a non-image file with image extension as an image', async () => { - const filePath = path.join(tempRootDir, 'fake-image.png'); - const fileContent = 'This is not a real png.'; - await fsp.writeFile(filePath, fileContent, 'utf-8'); - const params: ReadFileToolParams = { absolute_path: filePath }; + it('should treat a non-image file with image extension as an image', async () => { + const filePath = path.join(tempRootDir, 'fake-image.png'); + const fileContent = 'This is not a real png.'; + await fsp.writeFile(filePath, fileContent, 'utf-8'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: { - inlineData: { - mimeType: 'image/png', - data: Buffer.from(fileContent).toString('base64'), + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: { + inlineData: { + mimeType: 'image/png', + data: Buffer.from(fileContent).toString('base64'), + }, }, - }, - returnDisplay: `Read image file: fake-image.png`, + returnDisplay: `Read image file: fake-image.png`, + }); }); - }); - it('should pass offset and limit to read a slice of a text file', async () => { - const filePath = path.join(tempRootDir, 'paginated.txt'); - const fileContent = Array.from( - { length: 20 }, - (_, i) => `Line ${i + 1}`, - ).join('\n'); - await fsp.writeFile(filePath, fileContent, 'utf-8'); + it('should pass offset and limit to read a slice of a text file', async () => { + const filePath = path.join(tempRootDir, 'paginated.txt'); + const fileContent = Array.from( + { length: 20 }, + (_, i) => `Line ${i + 1}`, + ).join('\n'); + await fsp.writeFile(filePath, fileContent, 'utf-8'); - const params: ReadFileToolParams = { - absolute_path: filePath, - offset: 5, // Start from line 6 - limit: 3, - }; + const params: ReadFileToolParams = { + absolute_path: filePath, + offset: 5, // Start from line 6 + limit: 3, + }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: [ - '[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]', - 'Line 6', - 'Line 7', - 'Line 8', - ].join('\n'), - returnDisplay: 'Read lines 6-8 of 20 from paginated.txt', + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: [ + '[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]', + 'Line 6', + 'Line 7', + 'Line 8', + ].join('\n'), + returnDisplay: 'Read lines 6-8 of 20 from paginated.txt', + }); }); - }); - describe('with .geminiignore', () => { - beforeEach(async () => { - await fsp.writeFile( - path.join(tempRootDir, '.geminiignore'), - ['foo.*', 'ignored/'].join('\n'), - ); - }); + describe('with .geminiignore', () => { + beforeEach(async () => { + await fsp.writeFile( + path.join(tempRootDir, '.geminiignore'), + ['foo.*', 'ignored/'].join('\n'), + ); + }); - it('should return error if path is ignored by a .geminiignore pattern', async () => { - const ignoredFilePath = path.join(tempRootDir, 'foo.bar'); - await fsp.writeFile(ignoredFilePath, 'content', 'utf-8'); - const params: ReadFileToolParams = { - absolute_path: ignoredFilePath, - }; - const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`, - returnDisplay: expectedError, + it('should throw error if path is ignored by a .geminiignore pattern', async () => { + const ignoredFilePath = path.join(tempRootDir, 'foo.bar'); + await fsp.writeFile(ignoredFilePath, 'content', 'utf-8'); + const params: ReadFileToolParams = { + absolute_path: ignoredFilePath, + }; + const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`; + expect(() => tool.build(params)).toThrow(expectedError); }); - }); - it('should return error if path is in an ignored directory', async () => { - const ignoredDirPath = path.join(tempRootDir, 'ignored'); - await fsp.mkdir(ignoredDirPath); - const filePath = path.join(ignoredDirPath, 'somefile.txt'); - await fsp.writeFile(filePath, 'content', 'utf-8'); + it('should throw error if path is in an ignored directory', async () => { + const ignoredDirPath = path.join(tempRootDir, 'ignored'); + await fsp.mkdir(ignoredDirPath); + const filePath = path.join(ignoredDirPath, 'somefile.txt'); + await fsp.writeFile(filePath, 'content', 'utf-8'); - const params: ReadFileToolParams = { - absolute_path: filePath, - }; - const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`, - returnDisplay: expectedError, + const params: ReadFileToolParams = { + absolute_path: filePath, + }; + const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`; + expect(() => tool.build(params)).toThrow(expectedError); }); }); }); @@ -270,18 +290,16 @@ describe('ReadFileTool', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'file.txt'), }; - expect(tool.validateToolParams(params)).toBeNull(); + expect(() => tool.build(params)).not.toThrow(); }); it('should reject paths outside workspace root', () => { const params: ReadFileToolParams = { absolute_path: '/etc/passwd', }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); - expect(error).toContain(tempRootDir); }); it('should provide clear error message with workspace directories', () => { @@ -289,11 +307,9 @@ describe('ReadFileTool', () => { const params: ReadFileToolParams = { absolute_path: outsidePath, }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); - expect(error).toContain(tempRootDir); }); }); }); diff --git a/packages/core/src/tools/read-file.ts b/packages/core/src/tools/read-file.ts index 31282c20..3a05da06 100644 --- a/packages/core/src/tools/read-file.ts +++ b/packages/core/src/tools/read-file.ts @@ -7,7 +7,13 @@ import path from 'path'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; -import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js'; +import { + BaseDeclarativeTool, + Icon, + ToolInvocation, + ToolLocation, + ToolResult, +} from './tools.js'; import { Type } from '@google/genai'; import { processSingleFileContent, @@ -39,10 +45,72 @@ export interface ReadFileToolParams { limit?: number; } +class ReadFileToolInvocation + implements ToolInvocation<ReadFileToolParams, ToolResult> +{ + constructor( + private config: Config, + public params: ReadFileToolParams, + ) {} + + getDescription(): string { + const relativePath = makeRelative( + this.params.absolute_path, + this.config.getTargetDir(), + ); + return shortenPath(relativePath); + } + + toolLocations(): ToolLocation[] { + return [{ path: this.params.absolute_path, line: this.params.offset }]; + } + + shouldConfirmExecute(): Promise<false> { + return Promise.resolve(false); + } + + async execute(): Promise<ToolResult> { + const result = await processSingleFileContent( + this.params.absolute_path, + this.config.getTargetDir(), + this.params.offset, + this.params.limit, + ); + + if (result.error) { + return { + llmContent: result.error, // The detailed error for LLM + returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error + }; + } + + const lines = + typeof result.llmContent === 'string' + ? result.llmContent.split('\n').length + : undefined; + const mimetype = getSpecificMimeType(this.params.absolute_path); + recordFileOperationMetric( + this.config, + FileOperation.READ, + lines, + mimetype, + path.extname(this.params.absolute_path), + ); + + return { + llmContent: result.llmContent || '', + returnDisplay: result.returnDisplay || '', + }; + } +} + /** * Implementation of the ReadFile tool logic */ -export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> { +export class ReadFileTool extends BaseDeclarativeTool< + ReadFileToolParams, + ToolResult +> { static readonly Name: string = 'read_file'; constructor(private config: Config) { @@ -75,7 +143,7 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> { ); } - validateToolParams(params: ReadFileToolParams): string | null { + protected validateToolParams(params: ReadFileToolParams): string | null { const errors = SchemaValidator.validate(this.schema.parameters, params); if (errors) { return errors; @@ -106,67 +174,9 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> { return null; } - getDescription(params: ReadFileToolParams): string { - if ( - !params || - typeof params.absolute_path !== 'string' || - params.absolute_path.trim() === '' - ) { - return `Path unavailable`; - } - const relativePath = makeRelative( - params.absolute_path, - this.config.getTargetDir(), - ); - return shortenPath(relativePath); - } - - toolLocations(params: ReadFileToolParams): ToolLocation[] { - return [{ path: params.absolute_path, line: params.offset }]; - } - - async execute( + protected createInvocation( params: ReadFileToolParams, - _signal: AbortSignal, - ): Promise<ToolResult> { - const validationError = this.validateToolParams(params); - if (validationError) { - return { - llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`, - returnDisplay: validationError, - }; - } - - const result = await processSingleFileContent( - params.absolute_path, - this.config.getTargetDir(), - params.offset, - params.limit, - ); - - if (result.error) { - return { - llmContent: result.error, // The detailed error for LLM - returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error - }; - } - - const lines = - typeof result.llmContent === 'string' - ? result.llmContent.split('\n').length - : undefined; - const mimetype = getSpecificMimeType(params.absolute_path); - recordFileOperationMetric( - this.config, - FileOperation.READ, - lines, - mimetype, - path.extname(params.absolute_path), - ); - - return { - llmContent: result.llmContent || '', - returnDisplay: result.returnDisplay || '', - }; + ): ToolInvocation<ReadFileToolParams, ToolResult> { + return new ReadFileToolInvocation(this.config, params); } } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 24b6ca5f..e7c71e14 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -21,7 +21,6 @@ import { sanitizeParameters, } from './tool-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import { BaseTool, Icon, ToolResult } from './tools.js'; import { FunctionDeclaration, CallableTool, @@ -32,6 +31,7 @@ import { import { spawn } from 'node:child_process'; import fs from 'node:fs'; +import { MockTool } from '../test-utils/tools.js'; vi.mock('node:fs'); @@ -107,28 +107,6 @@ const createMockCallableTool = ( callTool: vi.fn(), }); -class MockTool extends BaseTool<{ param: string }, ToolResult> { - constructor( - name = 'mock-tool', - displayName = 'A mock tool', - description = 'A mock tool description', - ) { - super(name, displayName, description, Icon.Hammer, { - type: Type.OBJECT, - properties: { - param: { type: Type.STRING }, - }, - required: ['param'], - }); - } - async execute(params: { param: string }): Promise<ToolResult> { - return { - llmContent: `Executed with ${params.param}`, - returnDisplay: `Executed with ${params.param}`, - }; - } -} - const baseConfigParams: ConfigParameters = { cwd: '/tmp', model: 'test-model', diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index e60b8f74..73b427d4 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -5,7 +5,7 @@ */ import { FunctionDeclaration, Schema, Type } from '@google/genai'; -import { Tool, ToolResult, BaseTool, Icon } from './tools.js'; +import { AnyDeclarativeTool, Icon, ToolResult, BaseTool } from './tools.js'; import { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; @@ -125,7 +125,7 @@ Signal: Signal number or \`(none)\` if no signal was received. } export class ToolRegistry { - private tools: Map<string, Tool> = new Map(); + private tools: Map<string, AnyDeclarativeTool> = new Map(); private config: Config; constructor(config: Config) { @@ -136,7 +136,7 @@ export class ToolRegistry { * Registers a tool definition. * @param tool - The tool object containing schema and execution logic. */ - registerTool(tool: Tool): void { + registerTool(tool: AnyDeclarativeTool): void { if (this.tools.has(tool.name)) { if (tool instanceof DiscoveredMCPTool) { tool = tool.asFullyQualifiedTool(); @@ -368,7 +368,7 @@ export class ToolRegistry { /** * Returns an array of all registered and discovered tool instances. */ - getAllTools(): Tool[] { + getAllTools(): AnyDeclarativeTool[] { return Array.from(this.tools.values()).sort((a, b) => a.displayName.localeCompare(b.displayName), ); @@ -377,8 +377,8 @@ export class ToolRegistry { /** * Returns an array of tools registered from a specific MCP server. */ - getToolsByServer(serverName: string): Tool[] { - const serverTools: Tool[] = []; + getToolsByServer(serverName: string): AnyDeclarativeTool[] { + const serverTools: AnyDeclarativeTool[] = []; for (const tool of this.tools.values()) { if ((tool as DiscoveredMCPTool)?.serverName === serverName) { serverTools.push(tool); @@ -390,7 +390,7 @@ export class ToolRegistry { /** * Get the definition of a specific tool. */ - getTool(name: string): Tool | undefined { + getTool(name: string): AnyDeclarativeTool | undefined { return this.tools.get(name); } } diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 3404093f..79e6f010 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -9,101 +9,243 @@ import { ToolErrorType } from './tool-error.js'; import { DiffUpdateResult } from '../ide/ideContext.js'; /** - * Interface representing the base Tool functionality + * Represents a validated and ready-to-execute tool call. + * An instance of this is created by a `ToolBuilder`. */ -export interface Tool< - TParams = unknown, - TResult extends ToolResult = ToolResult, +export interface ToolInvocation< + TParams extends object, + TResult extends ToolResult, > { /** - * The internal name of the tool (used for API calls) + * The validated parameters for this specific invocation. + */ + params: TParams; + + /** + * Gets a pre-execution description of the tool operation. + * @returns A markdown string describing what the tool will do. + */ + getDescription(): string; + + /** + * Determines what file system paths the tool will affect. + * @returns A list of such paths. + */ + toolLocations(): ToolLocation[]; + + /** + * Determines if the tool should prompt for confirmation before execution. + * @returns Confirmation details or false if no confirmation is needed. + */ + shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise<ToolCallConfirmationDetails | false>; + + /** + * Executes the tool with the validated parameters. + * @param signal AbortSignal for tool cancellation. + * @param updateOutput Optional callback to stream output. + * @returns Result of the tool execution. + */ + execute( + signal: AbortSignal, + updateOutput?: (output: string) => void, + ): Promise<TResult>; +} + +/** + * A type alias for a tool invocation where the specific parameter and result types are not known. + */ +export type AnyToolInvocation = ToolInvocation<object, ToolResult>; + +/** + * An adapter that wraps the legacy `Tool` interface to make it compatible + * with the new `ToolInvocation` pattern. + */ +export class LegacyToolInvocation< + TParams extends object, + TResult extends ToolResult, +> implements ToolInvocation<TParams, TResult> +{ + constructor( + private readonly legacyTool: BaseTool<TParams, TResult>, + readonly params: TParams, + ) {} + + getDescription(): string { + return this.legacyTool.getDescription(this.params); + } + + toolLocations(): ToolLocation[] { + return this.legacyTool.toolLocations(this.params); + } + + shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise<ToolCallConfirmationDetails | false> { + return this.legacyTool.shouldConfirmExecute(this.params, abortSignal); + } + + execute( + signal: AbortSignal, + updateOutput?: (output: string) => void, + ): Promise<TResult> { + return this.legacyTool.execute(this.params, signal, updateOutput); + } +} + +/** + * Interface for a tool builder that validates parameters and creates invocations. + */ +export interface ToolBuilder< + TParams extends object, + TResult extends ToolResult, +> { + /** + * The internal name of the tool (used for API calls). */ name: string; /** - * The user-friendly display name of the tool + * The user-friendly display name of the tool. */ displayName: string; /** - * Description of what the tool does + * Description of what the tool does. */ description: string; /** - * The icon to display when interacting via ACP + * The icon to display when interacting via ACP. */ icon: Icon; /** - * Function declaration schema from @google/genai + * Function declaration schema from @google/genai. */ schema: FunctionDeclaration; /** - * Whether the tool's output should be rendered as markdown + * Whether the tool's output should be rendered as markdown. */ isOutputMarkdown: boolean; /** - * Whether the tool supports live (streaming) output + * Whether the tool supports live (streaming) output. */ canUpdateOutput: boolean; /** - * Validates the parameters for the tool - * Should be called from both `shouldConfirmExecute` and `execute` - * `shouldConfirmExecute` should return false immediately if invalid - * @param params Parameters to validate - * @returns An error message string if invalid, null otherwise + * Validates raw parameters and builds a ready-to-execute invocation. + * @param params The raw, untrusted parameters from the model. + * @returns A valid `ToolInvocation` if successful. Throws an error if validation fails. */ - validateToolParams(params: TParams): string | null; + build(params: TParams): ToolInvocation<TParams, TResult>; +} - /** - * Gets a pre-execution description of the tool operation - * @param params Parameters for the tool execution - * @returns A markdown string describing what the tool will do - * Optional for backward compatibility - */ - getDescription(params: TParams): string; +/** + * New base class for tools that separates validation from execution. + * New tools should extend this class. + */ +export abstract class DeclarativeTool< + TParams extends object, + TResult extends ToolResult, +> implements ToolBuilder<TParams, TResult> +{ + constructor( + readonly name: string, + readonly displayName: string, + readonly description: string, + readonly icon: Icon, + readonly parameterSchema: Schema, + readonly isOutputMarkdown: boolean = true, + readonly canUpdateOutput: boolean = false, + ) {} + + get schema(): FunctionDeclaration { + return { + name: this.name, + description: this.description, + parameters: this.parameterSchema, + }; + } /** - * Determines what file system paths the tool will affect - * @param params Parameters for the tool execution - * @returns A list of such paths + * Validates the raw tool parameters. + * Subclasses should override this to add custom validation logic + * beyond the JSON schema check. + * @param params The raw parameters from the model. + * @returns An error message string if invalid, null otherwise. */ - toolLocations(params: TParams): ToolLocation[]; + protected validateToolParams(_params: TParams): string | null { + // Base implementation can be extended by subclasses. + return null; + } /** - * Determines if the tool should prompt for confirmation before execution - * @param params Parameters for the tool execution - * @returns Whether execute should be confirmed. + * The core of the new pattern. It validates parameters and, if successful, + * returns a `ToolInvocation` object that encapsulates the logic for the + * specific, validated call. + * @param params The raw, untrusted parameters from the model. + * @returns A `ToolInvocation` instance. */ - shouldConfirmExecute( - params: TParams, - abortSignal: AbortSignal, - ): Promise<ToolCallConfirmationDetails | false>; + abstract build(params: TParams): ToolInvocation<TParams, TResult>; /** - * Executes the tool with the given parameters - * @param params Parameters for the tool execution - * @returns Result of the tool execution + * A convenience method that builds and executes the tool in one step. + * Throws an error if validation fails. + * @param params The raw, untrusted parameters from the model. + * @param signal AbortSignal for tool cancellation. + * @param updateOutput Optional callback to stream output. + * @returns The result of the tool execution. */ - execute( + async buildAndExecute( params: TParams, signal: AbortSignal, updateOutput?: (output: string) => void, - ): Promise<TResult>; + ): Promise<TResult> { + const invocation = this.build(params); + return invocation.execute(signal, updateOutput); + } +} + +/** + * New base class for declarative tools that separates validation from execution. + * New tools should extend this class, which provides a `build` method that + * validates parameters before deferring to a `createInvocation` method for + * the final `ToolInvocation` object instantiation. + */ +export abstract class BaseDeclarativeTool< + TParams extends object, + TResult extends ToolResult, +> extends DeclarativeTool<TParams, TResult> { + build(params: TParams): ToolInvocation<TParams, TResult> { + const validationError = this.validateToolParams(params); + if (validationError) { + throw new Error(validationError); + } + return this.createInvocation(params); + } + + protected abstract createInvocation( + params: TParams, + ): ToolInvocation<TParams, TResult>; } /** + * A type alias for a declarative tool where the specific parameter and result types are not known. + */ +export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>; + +/** * Base implementation for tools with common functionality + * @deprecated Use `DeclarativeTool` for new tools. */ export abstract class BaseTool< - TParams = unknown, + TParams extends object, TResult extends ToolResult = ToolResult, -> implements Tool<TParams, TResult> -{ +> extends DeclarativeTool<TParams, TResult> { /** * Creates a new instance of BaseTool * @param name Internal name of the tool (used for API calls) @@ -121,17 +263,24 @@ export abstract class BaseTool< readonly parameterSchema: Schema, readonly isOutputMarkdown: boolean = true, readonly canUpdateOutput: boolean = false, - ) {} + ) { + super( + name, + displayName, + description, + icon, + parameterSchema, + isOutputMarkdown, + canUpdateOutput, + ); + } - /** - * Function declaration schema computed from name, description, and parameterSchema - */ - get schema(): FunctionDeclaration { - return { - name: this.name, - description: this.description, - parameters: this.parameterSchema, - }; + build(params: TParams): ToolInvocation<TParams, TResult> { + const validationError = this.validateToolParams(params); + if (validationError) { + throw new Error(validationError); + } + return new LegacyToolInvocation(this, params); } /** diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 32ecc068..9e7e3813 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -26,7 +26,7 @@ import { ensureCorrectFileContent, } from '../utils/editCorrector.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; import { getSpecificMimeType } from '../utils/fileUtils.js'; import { recordFileOperationMetric, @@ -66,7 +66,7 @@ interface GetCorrectedFileContentResult { */ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> - implements ModifiableTool<WriteFileToolParams> + implements ModifiableDeclarativeTool<WriteFileToolParams> { static readonly Name: string = 'write_file'; |
