diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/acp/acpPeer.ts | 149 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/atCommandProcessor.ts | 13 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 56 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useReactToolScheduler.ts | 27 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useToolScheduler.test.ts | 180 |
5 files changed, 242 insertions, 183 deletions
diff --git a/packages/cli/src/acp/acpPeer.ts b/packages/cli/src/acp/acpPeer.ts index 90952b7f..40d8753f 100644 --- a/packages/cli/src/acp/acpPeer.ts +++ b/packages/cli/src/acp/acpPeer.ts @@ -239,65 +239,62 @@ class GeminiAgent implements Agent { ); } - let toolCallId; - const confirmationDetails = await tool.shouldConfirmExecute( - args, - abortSignal, - ); - if (confirmationDetails) { - let content: acp.ToolCallContent | null = null; - if (confirmationDetails.type === 'edit') { - content = { - type: 'diff', - path: confirmationDetails.fileName, - oldText: confirmationDetails.originalContent, - newText: confirmationDetails.newContent, - }; - } + let toolCallId: number | undefined = undefined; + try { + const invocation = tool.build(args); + const confirmationDetails = + await invocation.shouldConfirmExecute(abortSignal); + if (confirmationDetails) { + let content: acp.ToolCallContent | null = null; + if (confirmationDetails.type === 'edit') { + content = { + type: 'diff', + path: confirmationDetails.fileName, + oldText: confirmationDetails.originalContent, + newText: confirmationDetails.newContent, + }; + } - const result = await this.client.requestToolCallConfirmation({ - label: tool.getDescription(args), - icon: tool.icon, - content, - confirmation: toAcpToolCallConfirmation(confirmationDetails), - locations: tool.toolLocations(args), - }); + const result = await this.client.requestToolCallConfirmation({ + label: invocation.getDescription(), + icon: tool.icon, + content, + confirmation: toAcpToolCallConfirmation(confirmationDetails), + locations: invocation.toolLocations(), + }); - await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome)); - switch (result.outcome) { - case 'reject': - return errorResponse( - new Error(`Tool "${fc.name}" not allowed to run by the user.`), - ); + await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome)); + switch (result.outcome) { + case 'reject': + return errorResponse( + new Error(`Tool "${fc.name}" not allowed to run by the user.`), + ); - case 'cancel': - return errorResponse( - new Error(`Tool "${fc.name}" was canceled by the user.`), - ); - case 'allow': - case 'alwaysAllow': - case 'alwaysAllowMcpServer': - case 'alwaysAllowTool': - break; - default: { - const resultOutcome: never = result.outcome; - throw new Error(`Unexpected: ${resultOutcome}`); + case 'cancel': + return errorResponse( + new Error(`Tool "${fc.name}" was canceled by the user.`), + ); + case 'allow': + case 'alwaysAllow': + case 'alwaysAllowMcpServer': + case 'alwaysAllowTool': + break; + default: { + const resultOutcome: never = result.outcome; + throw new Error(`Unexpected: ${resultOutcome}`); + } } + toolCallId = result.id; + } else { + const result = await this.client.pushToolCall({ + icon: tool.icon, + label: invocation.getDescription(), + locations: invocation.toolLocations(), + }); + toolCallId = result.id; } - toolCallId = result.id; - } else { - const result = await this.client.pushToolCall({ - icon: tool.icon, - label: tool.getDescription(args), - locations: tool.toolLocations(args), - }); - - toolCallId = result.id; - } - - try { - const toolResult: ToolResult = await tool.execute(args, abortSignal); + const toolResult: ToolResult = await invocation.execute(abortSignal); const toolCallContent = toToolCallContent(toolResult); await this.client.updateToolCall({ @@ -320,12 +317,13 @@ class GeminiAgent implements Agent { return convertToFunctionResponse(fc.name, callId, toolResult.llmContent); } catch (e) { const error = e instanceof Error ? e : new Error(String(e)); - await this.client.updateToolCall({ - toolCallId, - status: 'error', - content: { type: 'markdown', markdown: error.message }, - }); - + if (toolCallId) { + await this.client.updateToolCall({ + toolCallId, + status: 'error', + content: { type: 'markdown', markdown: error.message }, + }); + } return errorResponse(error); } } @@ -408,7 +406,7 @@ class GeminiAgent implements Agent { `Path ${pathName} not found directly, attempting glob search.`, ); try { - const globResult = await globTool.execute( + const globResult = await globTool.buildAndExecute( { pattern: `**/*${pathName}*`, path: this.config.getTargetDir(), @@ -530,12 +528,15 @@ class GeminiAgent implements Agent { respectGitIgnore, // Use configuration setting }; - const toolCall = await this.client.pushToolCall({ - icon: readManyFilesTool.icon, - label: readManyFilesTool.getDescription(toolArgs), - }); + let toolCallId: number | undefined = undefined; try { - const result = await readManyFilesTool.execute(toolArgs, abortSignal); + const invocation = readManyFilesTool.build(toolArgs); + const toolCall = await this.client.pushToolCall({ + icon: readManyFilesTool.icon, + label: invocation.getDescription(), + }); + toolCallId = toolCall.id; + const result = await invocation.execute(abortSignal); const content = toToolCallContent(result) || { type: 'markdown', markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, @@ -578,14 +579,16 @@ class GeminiAgent implements Agent { return processedQueryParts; } catch (error: unknown) { - await this.client.updateToolCall({ - toolCallId: toolCall.id, - status: 'error', - content: { - type: 'markdown', - markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, - }, - }); + if (toolCallId) { + await this.client.updateToolCall({ + toolCallId, + status: 'error', + content: { + type: 'markdown', + markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, + }, + }); + } throw error; } } diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 165b7b30..cef2f811 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -8,6 +8,7 @@ import * as fs from 'fs/promises'; import * as path from 'path'; import { PartListUnion, PartUnion } from '@google/genai'; import { + AnyToolInvocation, Config, getErrorMessage, isNodeError, @@ -254,7 +255,7 @@ export async function handleAtCommand({ `Path ${pathName} not found directly, attempting glob search.`, ); try { - const globResult = await globTool.execute( + const globResult = await globTool.buildAndExecute( { pattern: `**/*${pathName}*`, path: dir, @@ -411,12 +412,14 @@ export async function handleAtCommand({ }; let toolCallDisplay: IndividualToolCallDisplay; + let invocation: AnyToolInvocation | undefined = undefined; try { - const result = await readManyFilesTool.execute(toolArgs, signal); + invocation = readManyFilesTool.build(toolArgs); + const result = await invocation.execute(signal); toolCallDisplay = { callId: `client-read-${userMessageTimestamp}`, name: readManyFilesTool.displayName, - description: readManyFilesTool.getDescription(toolArgs), + description: invocation.getDescription(), status: ToolCallStatus.Success, resultDisplay: result.returnDisplay || @@ -466,7 +469,9 @@ export async function handleAtCommand({ toolCallDisplay = { callId: `client-read-${userMessageTimestamp}`, name: readManyFilesTool.displayName, - description: readManyFilesTool.getDescription(toolArgs), + description: + invocation?.getDescription() ?? + 'Error attempting to execute tool to read files', status: ToolCallStatus.Error, resultDisplay: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, confirmationDetails: undefined, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 062c1687..dd2428bb 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -21,6 +21,7 @@ import { EditorType, AuthType, GeminiEventType as ServerGeminiEventType, + AnyToolInvocation, } from '@google/gemini-cli-core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -452,9 +453,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'tool1', + displayName: 'tool1', description: 'desc1', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), endTime: Date.now(), } as TrackedCompletedToolCall, @@ -469,9 +474,13 @@ describe('useGeminiStream', () => { responseSubmittedToGemini: false, tool: { name: 'tool2', + displayName: 'tool2', description: 'desc2', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), liveOutput: '...', } as TrackedExecutingToolCall, @@ -506,6 +515,12 @@ describe('useGeminiStream', () => { status: 'success', responseSubmittedToGemini: false, response: { callId: 'call1', responseParts: toolCall1ResponseParts }, + tool: { + displayName: 'MockTool', + }, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, } as TrackedCompletedToolCall, { request: { @@ -584,6 +599,12 @@ describe('useGeminiStream', () => { status: 'cancelled', response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, responseSubmittedToGemini: false, + tool: { + displayName: 'mock tool', + }, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, } as TrackedCancelledToolCall, ]; const client = new MockedGeminiClientClass(mockConfig); @@ -644,9 +665,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'toolA', + displayName: 'toolA', description: 'descA', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, status: 'cancelled', response: { callId: 'cancel-1', @@ -668,9 +693,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'toolB', + displayName: 'toolB', description: 'descB', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, status: 'cancelled', response: { callId: 'cancel-2', @@ -760,9 +789,13 @@ describe('useGeminiStream', () => { responseSubmittedToGemini: false, tool: { name: 'tool1', + displayName: 'tool1', description: 'desc', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), } as TrackedExecutingToolCall, ]; @@ -980,8 +1013,13 @@ describe('useGeminiStream', () => { tool: { name: 'tool1', description: 'desc1', - getDescription: vi.fn(), + build: vi.fn().mockImplementation((_) => ({ + getDescription: () => `Mock description`, + })), } as any, + invocation: { + getDescription: () => `Mock description`, + }, startTime: Date.now(), liveOutput: '...', } as TrackedExecutingToolCall, @@ -1131,9 +1169,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'save_memory', + displayName: 'save_memory', description: 'Saves memory', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, }; // Capture the onComplete callback diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index 01993650..c6b802fc 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -17,7 +17,6 @@ import { OutputUpdateHandler, AllToolCallsCompleteHandler, ToolCallsUpdateHandler, - Tool, ToolCall, Status as CoreStatus, EditorType, @@ -216,23 +215,20 @@ export function mapToDisplay( const toolDisplays = toolCalls.map( (trackedCall): IndividualToolCallDisplay => { - let displayName = trackedCall.request.name; - let description = ''; + let displayName: string; + let description: string; let renderOutputAsMarkdown = false; - const currentToolInstance = - 'tool' in trackedCall && trackedCall.tool - ? (trackedCall as { tool: Tool }).tool - : undefined; - - if (currentToolInstance) { - displayName = currentToolInstance.displayName; - description = currentToolInstance.getDescription( - trackedCall.request.args, - ); - renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown; - } else if ('request' in trackedCall && 'args' in trackedCall.request) { + if (trackedCall.status === 'error') { + displayName = + trackedCall.tool === undefined + ? trackedCall.request.name + : trackedCall.tool.displayName; description = JSON.stringify(trackedCall.request.args); + } else { + displayName = trackedCall.tool.displayName; + description = trackedCall.invocation.getDescription(); + renderOutputAsMarkdown = trackedCall.tool.isOutputMarkdown; } const baseDisplayProperties: Omit< @@ -256,7 +252,6 @@ export function mapToDisplay( case 'error': return { ...baseDisplayProperties, - name: currentToolInstance?.displayName ?? trackedCall.request.name, status: mapCoreStatusToDisplayStatus(trackedCall.status), resultDisplay: trackedCall.response.resultDisplay, confirmationDetails: undefined, diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 5395d18a..ee5251d3 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -15,7 +15,6 @@ import { PartUnion, FunctionResponse } from '@google/genai'; import { Config, ToolCallRequestInfo, - Tool, ToolRegistry, ToolResult, ToolCallConfirmationDetails, @@ -25,6 +24,9 @@ import { Status as ToolCallStatusType, ApprovalMode, Icon, + BaseTool, + AnyDeclarativeTool, + AnyToolInvocation, } from '@google/gemini-cli-core'; import { HistoryItemWithoutId, @@ -53,46 +55,55 @@ const mockConfig = { getDebugMode: () => false, }; -const mockTool: Tool = { - name: 'mockTool', - displayName: 'Mock Tool', - description: 'A mock tool for testing', - icon: Icon.Hammer, - toolLocations: vi.fn(), - isOutputMarkdown: false, - canUpdateOutput: false, - schema: {}, - validateToolParams: vi.fn(), - execute: vi.fn(), - shouldConfirmExecute: vi.fn(), - getDescription: vi.fn((args) => `Description for ${JSON.stringify(args)}`), -}; +class MockTool extends BaseTool<object, ToolResult> { + constructor( + name: string, + displayName: string, + canUpdateOutput = false, + shouldConfirm = false, + isOutputMarkdown = false, + ) { + super( + name, + displayName, + 'A mock tool for testing', + Icon.Hammer, + {}, + isOutputMarkdown, + canUpdateOutput, + ); + if (shouldConfirm) { + this.shouldConfirmExecute = vi.fn( + async (): Promise<ToolCallConfirmationDetails | false> => ({ + type: 'edit', + title: 'Mock Tool Requires Confirmation', + onConfirm: mockOnUserConfirmForToolConfirmation, + fileName: 'mockToolRequiresConfirmation.ts', + fileDiff: 'Mock tool requires confirmation', + originalContent: 'Original content', + newContent: 'New content', + }), + ); + } + } -const mockToolWithLiveOutput: Tool = { - ...mockTool, - name: 'mockToolWithLiveOutput', - displayName: 'Mock Tool With Live Output', - canUpdateOutput: true, -}; + execute = vi.fn(); + shouldConfirmExecute = vi.fn(); +} +const mockTool = new MockTool('mockTool', 'Mock Tool'); +const mockToolWithLiveOutput = new MockTool( + 'mockToolWithLiveOutput', + 'Mock Tool With Live Output', + true, +); let mockOnUserConfirmForToolConfirmation: Mock; - -const mockToolRequiresConfirmation: Tool = { - ...mockTool, - name: 'mockToolRequiresConfirmation', - displayName: 'Mock Tool Requires Confirmation', - shouldConfirmExecute: vi.fn( - async (): Promise<ToolCallConfirmationDetails | false> => ({ - type: 'edit', - title: 'Mock Tool Requires Confirmation', - onConfirm: mockOnUserConfirmForToolConfirmation, - fileName: 'mockToolRequiresConfirmation.ts', - fileDiff: 'Mock tool requires confirmation', - originalContent: 'Original content', - newContent: 'New content', - }), - ), -}; +const mockToolRequiresConfirmation = new MockTool( + 'mockToolRequiresConfirmation', + 'Mock Tool Requires Confirmation', + false, + true, +); describe('useReactToolScheduler in YOLO Mode', () => { let onComplete: Mock; @@ -646,28 +657,21 @@ describe('useReactToolScheduler', () => { }); it('should schedule and execute multiple tool calls', async () => { - const tool1 = { - ...mockTool, - name: 'tool1', - displayName: 'Tool 1', - execute: vi.fn().mockResolvedValue({ - llmContent: 'Output 1', - returnDisplay: 'Display 1', - summary: 'Summary 1', - } as ToolResult), - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - }; - const tool2 = { - ...mockTool, - name: 'tool2', - displayName: 'Tool 2', - execute: vi.fn().mockResolvedValue({ - llmContent: 'Output 2', - returnDisplay: 'Display 2', - summary: 'Summary 2', - } as ToolResult), - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - }; + const tool1 = new MockTool('tool1', 'Tool 1'); + tool1.execute.mockResolvedValue({ + llmContent: 'Output 1', + returnDisplay: 'Display 1', + summary: 'Summary 1', + } as ToolResult); + tool1.shouldConfirmExecute.mockResolvedValue(null); + + const tool2 = new MockTool('tool2', 'Tool 2'); + tool2.execute.mockResolvedValue({ + llmContent: 'Output 2', + returnDisplay: 'Display 2', + summary: 'Summary 2', + } as ToolResult); + tool2.shouldConfirmExecute.mockResolvedValue(null); mockToolRegistry.getTool.mockImplementation((name) => { if (name === 'tool1') return tool1; @@ -805,20 +809,7 @@ describe('mapToDisplay', () => { args: { foo: 'bar' }, }; - const baseTool: Tool = { - name: 'testTool', - displayName: 'Test Tool Display', - description: 'Test Description', - isOutputMarkdown: false, - canUpdateOutput: false, - schema: {}, - icon: Icon.Hammer, - toolLocations: vi.fn(), - validateToolParams: vi.fn(), - execute: vi.fn(), - shouldConfirmExecute: vi.fn(), - getDescription: vi.fn((args) => `Desc: ${JSON.stringify(args)}`), - }; + const baseTool = new MockTool('testTool', 'Test Tool Display'); const baseResponse: ToolCallResponseInfo = { callId: 'testCallId', @@ -840,13 +831,15 @@ describe('mapToDisplay', () => { // This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist. type MapToDisplayExtraProps = | { - tool?: Tool; + tool?: AnyDeclarativeTool; + invocation?: AnyToolInvocation; liveOutput?: string; response?: ToolCallResponseInfo; confirmationDetails?: ToolCallConfirmationDetails; } | { - tool: Tool; + tool: AnyDeclarativeTool; + invocation?: AnyToolInvocation; response?: ToolCallResponseInfo; confirmationDetails?: ToolCallConfirmationDetails; } @@ -857,10 +850,12 @@ describe('mapToDisplay', () => { } | { confirmationDetails: ToolCallConfirmationDetails; - tool?: Tool; + tool?: AnyDeclarativeTool; + invocation?: AnyToolInvocation; response?: ToolCallResponseInfo; }; + const baseInvocation = baseTool.build(baseRequest.args); const testCases: Array<{ name: string; status: ToolCallStatusType; @@ -873,7 +868,7 @@ describe('mapToDisplay', () => { { name: 'validating', status: 'validating', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -883,6 +878,7 @@ describe('mapToDisplay', () => { status: 'awaiting_approval', extraProps: { tool: baseTool, + invocation: baseInvocation, confirmationDetails: { onConfirm: vi.fn(), type: 'edit', @@ -903,7 +899,7 @@ describe('mapToDisplay', () => { { name: 'scheduled', status: 'scheduled', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Pending, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -911,7 +907,7 @@ describe('mapToDisplay', () => { { name: 'executing no live output', status: 'executing', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -919,7 +915,11 @@ describe('mapToDisplay', () => { { name: 'executing with live output', status: 'executing', - extraProps: { tool: baseTool, liveOutput: 'Live test output' }, + extraProps: { + tool: baseTool, + invocation: baseInvocation, + liveOutput: 'Live test output', + }, expectedStatus: ToolCallStatus.Executing, expectedResultDisplay: 'Live test output', expectedName: baseTool.displayName, @@ -928,7 +928,11 @@ describe('mapToDisplay', () => { { name: 'success', status: 'success', - extraProps: { tool: baseTool, response: baseResponse }, + extraProps: { + tool: baseTool, + invocation: baseInvocation, + response: baseResponse, + }, expectedStatus: ToolCallStatus.Success, expectedResultDisplay: baseResponse.resultDisplay as any, expectedName: baseTool.displayName, @@ -970,6 +974,7 @@ describe('mapToDisplay', () => { status: 'cancelled', extraProps: { tool: baseTool, + invocation: baseInvocation, response: { ...baseResponse, resultDisplay: 'Cancelled display', @@ -1030,12 +1035,21 @@ describe('mapToDisplay', () => { request: { ...baseRequest, callId: 'call1' }, status: 'success', tool: baseTool, + invocation: baseTool.build(baseRequest.args), response: { ...baseResponse, callId: 'call1' }, } as ToolCall; + const toolForCall2 = new MockTool( + baseTool.name, + baseTool.displayName, + false, + false, + true, + ); const toolCall2: ToolCall = { request: { ...baseRequest, callId: 'call2' }, status: 'executing', - tool: { ...baseTool, isOutputMarkdown: true }, + tool: toolForCall2, + invocation: toolForCall2.build(baseRequest.args), liveOutput: 'markdown output', } as ToolCall; |
