diff options
| author | Agus Zubiaga <[email protected]> | 2025-08-13 12:58:26 -0300 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-08-13 15:58:26 +0000 |
| commit | d3fda9dafb3921c9edd5cf4fc166dedecd91d84f (patch) | |
| tree | 355aa9b16d9a296515e7c0ed91aa94969c41bc70 /packages/cli | |
| parent | 150103e5ddaa3d6790f7d64e86b0e0deed576ad8 (diff) | |
Zed integration schema upgrade (#5536)
Co-authored-by: Conrad Irwin <[email protected]>
Co-authored-by: Ben Brandt <[email protected]>
Diffstat (limited to 'packages/cli')
| -rw-r--r-- | packages/cli/src/acp/acp.ts | 464 | ||||
| -rw-r--r-- | packages/cli/src/config/config.ts | 13 | ||||
| -rw-r--r-- | packages/cli/src/gemini.tsx | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useToolScheduler.test.ts | 4 | ||||
| -rw-r--r-- | packages/cli/src/zed-integration/acp.ts | 366 | ||||
| -rw-r--r-- | packages/cli/src/zed-integration/schema.ts | 457 | ||||
| -rw-r--r-- | packages/cli/src/zed-integration/zedIntegration.ts (renamed from packages/cli/src/acp/acpPeer.ts) | 610 |
7 files changed, 1226 insertions, 694 deletions
diff --git a/packages/cli/src/acp/acp.ts b/packages/cli/src/acp/acp.ts deleted file mode 100644 index 1fbdf7a8..00000000 --- a/packages/cli/src/acp/acp.ts +++ /dev/null @@ -1,464 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */ - -import { Icon } from '@google/gemini-cli-core'; -import { WritableStream, ReadableStream } from 'node:stream/web'; - -export class ClientConnection implements Client { - #connection: Connection<Agent>; - - constructor( - agent: (client: Client) => Agent, - input: WritableStream<Uint8Array>, - output: ReadableStream<Uint8Array>, - ) { - this.#connection = new Connection(agent(this), input, output); - } - - /** - * Streams part of an assistant response to the client - */ - async streamAssistantMessageChunk( - params: StreamAssistantMessageChunkParams, - ): Promise<void> { - await this.#connection.sendRequest('streamAssistantMessageChunk', params); - } - - /** - * Request confirmation before running a tool - * - * When allowed, the client returns a [`ToolCallId`] which can be used - * to update the tool call's `status` and `content` as it runs. - */ - requestToolCallConfirmation( - params: RequestToolCallConfirmationParams, - ): Promise<RequestToolCallConfirmationResponse> { - return this.#connection.sendRequest('requestToolCallConfirmation', params); - } - - /** - * pushToolCall allows the agent to start a tool call - * when it does not need to request permission to do so. - * - * The returned id can be used to update the UI for the tool - * call as needed. - */ - pushToolCall(params: PushToolCallParams): Promise<PushToolCallResponse> { - return this.#connection.sendRequest('pushToolCall', params); - } - - /** - * updateToolCall allows the agent to update the content and status of the tool call. - * - * The new content replaces what is currently displayed in the UI. - * - * The [`ToolCallId`] is included in the response of - * `pushToolCall` or `requestToolCallConfirmation` respectively. - */ - async updateToolCall(params: UpdateToolCallParams): Promise<void> { - await this.#connection.sendRequest('updateToolCall', params); - } -} - -type AnyMessage = AnyRequest | AnyResponse; - -type AnyRequest = { - id: number; - method: string; - params?: unknown; -}; - -type AnyResponse = { jsonrpc: '2.0'; id: number } & Result<unknown>; - -type Result<T> = - | { - result: T; - } - | { - error: ErrorResponse; - }; - -type ErrorResponse = { - code: number; - message: string; - data?: { details?: string }; -}; - -type PendingResponse = { - resolve: (response: unknown) => void; - reject: (error: ErrorResponse) => void; -}; - -class Connection<D> { - #pendingResponses: Map<number, PendingResponse> = new Map(); - #nextRequestId: number = 0; - #delegate: D; - #peerInput: WritableStream<Uint8Array>; - #writeQueue: Promise<void> = Promise.resolve(); - #textEncoder: TextEncoder; - - constructor( - delegate: D, - peerInput: WritableStream<Uint8Array>, - peerOutput: ReadableStream<Uint8Array>, - ) { - this.#peerInput = peerInput; - this.#textEncoder = new TextEncoder(); - - this.#delegate = delegate; - this.#receive(peerOutput); - } - - async #receive(output: ReadableStream<Uint8Array>) { - let content = ''; - const decoder = new TextDecoder(); - for await (const chunk of output) { - content += decoder.decode(chunk, { stream: true }); - const lines = content.split('\n'); - content = lines.pop() || ''; - - for (const line of lines) { - const trimmedLine = line.trim(); - - if (trimmedLine) { - const message = JSON.parse(trimmedLine); - this.#processMessage(message); - } - } - } - } - - async #processMessage(message: AnyMessage) { - if ('method' in message) { - const response = await this.#tryCallDelegateMethod( - message.method, - message.params, - ); - - await this.#sendMessage({ - jsonrpc: '2.0', - id: message.id, - ...response, - }); - } else { - this.#handleResponse(message); - } - } - - async #tryCallDelegateMethod( - method: string, - params?: unknown, - ): Promise<Result<unknown>> { - const methodName = method as keyof D; - if (typeof this.#delegate[methodName] !== 'function') { - return RequestError.methodNotFound(method).toResult(); - } - - try { - const result = await this.#delegate[methodName](params); - return { result: result ?? null }; - } catch (error: unknown) { - if (error instanceof RequestError) { - return error.toResult(); - } - - let details; - - if (error instanceof Error) { - details = error.message; - } else if ( - typeof error === 'object' && - error != null && - 'message' in error && - typeof error.message === 'string' - ) { - details = error.message; - } - - return RequestError.internalError(details).toResult(); - } - } - - #handleResponse(response: AnyResponse) { - const pendingResponse = this.#pendingResponses.get(response.id); - if (pendingResponse) { - if ('result' in response) { - pendingResponse.resolve(response.result); - } else if ('error' in response) { - pendingResponse.reject(response.error); - } - this.#pendingResponses.delete(response.id); - } - } - - async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> { - const id = this.#nextRequestId++; - const responsePromise = new Promise((resolve, reject) => { - this.#pendingResponses.set(id, { resolve, reject }); - }); - await this.#sendMessage({ jsonrpc: '2.0', id, method, params }); - return responsePromise as Promise<Resp>; - } - - async #sendMessage(json: AnyMessage) { - const content = JSON.stringify(json) + '\n'; - this.#writeQueue = this.#writeQueue - .then(async () => { - const writer = this.#peerInput.getWriter(); - try { - await writer.write(this.#textEncoder.encode(content)); - } finally { - writer.releaseLock(); - } - }) - .catch((error) => { - // Continue processing writes on error - console.error('ACP write error:', error); - }); - return this.#writeQueue; - } -} - -export class RequestError extends Error { - data?: { details?: string }; - - constructor( - public code: number, - message: string, - details?: string, - ) { - super(message); - this.name = 'RequestError'; - if (details) { - this.data = { details }; - } - } - - static parseError(details?: string): RequestError { - return new RequestError(-32700, 'Parse error', details); - } - - static invalidRequest(details?: string): RequestError { - return new RequestError(-32600, 'Invalid request', details); - } - - static methodNotFound(details?: string): RequestError { - return new RequestError(-32601, 'Method not found', details); - } - - static invalidParams(details?: string): RequestError { - return new RequestError(-32602, 'Invalid params', details); - } - - static internalError(details?: string): RequestError { - return new RequestError(-32603, 'Internal error', details); - } - - toResult<T>(): Result<T> { - return { - error: { - code: this.code, - message: this.message, - data: this.data, - }, - }; - } -} - -// Protocol types - -export const LATEST_PROTOCOL_VERSION = '0.0.9'; - -export type AssistantMessageChunk = - | { - text: string; - } - | { - thought: string; - }; - -export type ToolCallConfirmation = - | { - description?: string | null; - type: 'edit'; - } - | { - description?: string | null; - type: 'execute'; - command: string; - rootCommand: string; - } - | { - description?: string | null; - type: 'mcp'; - serverName: string; - toolDisplayName: string; - toolName: string; - } - | { - description?: string | null; - type: 'fetch'; - urls: string[]; - } - | { - description: string; - type: 'other'; - }; - -export type ToolCallContent = - | { - type: 'markdown'; - markdown: string; - } - | { - type: 'diff'; - newText: string; - oldText: string | null; - path: string; - }; - -export type ToolCallStatus = 'running' | 'finished' | 'error'; - -export type ToolCallId = number; - -export type ToolCallConfirmationOutcome = - | 'allow' - | 'alwaysAllow' - | 'alwaysAllowMcpServer' - | 'alwaysAllowTool' - | 'reject' - | 'cancel'; - -/** - * A part in a user message - */ -export type UserMessageChunk = - | { - text: string; - } - | { - path: string; - }; - -export interface StreamAssistantMessageChunkParams { - chunk: AssistantMessageChunk; -} - -export interface RequestToolCallConfirmationParams { - confirmation: ToolCallConfirmation; - content?: ToolCallContent | null; - icon: Icon; - label: string; - locations?: ToolCallLocation[]; -} - -export interface ToolCallLocation { - line?: number | null; - path: string; -} - -export interface PushToolCallParams { - content?: ToolCallContent | null; - icon: Icon; - label: string; - locations?: ToolCallLocation[]; -} - -export interface UpdateToolCallParams { - content: ToolCallContent | null; - status: ToolCallStatus; - toolCallId: ToolCallId; -} - -export interface RequestToolCallConfirmationResponse { - id: ToolCallId; - outcome: ToolCallConfirmationOutcome; -} - -export interface PushToolCallResponse { - id: ToolCallId; -} - -export interface InitializeParams { - /** - * The version of the protocol that the client supports. - * This should be the latest version supported by the client. - */ - protocolVersion: string; -} - -export interface SendUserMessageParams { - chunks: UserMessageChunk[]; -} - -export interface InitializeResponse { - /** - * Indicates whether the agent is authenticated and - * ready to handle requests. - */ - isAuthenticated: boolean; - /** - * The version of the protocol that the agent supports. - * If the agent supports the requested version, it should respond with the same version. - * Otherwise, the agent should respond with the latest version it supports. - */ - protocolVersion: string; -} - -export interface Error { - code: number; - data?: unknown; - message: string; -} - -export interface Client { - streamAssistantMessageChunk( - params: StreamAssistantMessageChunkParams, - ): Promise<void>; - - requestToolCallConfirmation( - params: RequestToolCallConfirmationParams, - ): Promise<RequestToolCallConfirmationResponse>; - - pushToolCall(params: PushToolCallParams): Promise<PushToolCallResponse>; - - updateToolCall(params: UpdateToolCallParams): Promise<void>; -} - -export interface Agent { - /** - * Initializes the agent's state. It should be called before any other method, - * and no other methods should be called until it has completed. - * - * If the agent is not authenticated, then the client should prompt the user to authenticate, - * and then call the `authenticate` method. - * Otherwise the client can send other messages to the agent. - */ - initialize(params: InitializeParams): Promise<InitializeResponse>; - - /** - * Begins the authentication process. - * - * This method should only be called if `initialize` indicates the user isn't already authenticated. - * The Promise MUST not resolve until authentication is complete. - */ - authenticate(): Promise<void>; - - /** - * Allows the user to send a message to the agent. - * This method should complete after the agent is finished, during - * which time the agent may update the client by calling - * streamAssistantMessageChunk and other methods. - */ - sendUserMessage(params: SendUserMessageParams): Promise<void>; - - /** - * Cancels the current generation. - */ - cancelSendMessage(): Promise<void>; -} diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index dd207ff2..636696fa 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -304,6 +304,7 @@ export async function loadCliConfig( extensions: Extension[], sessionId: string, argv: CliArgs, + cwd: string = process.cwd(), ): Promise<Config> { const debugMode = argv.debug || @@ -343,7 +344,7 @@ export async function loadCliConfig( (e) => e.contextFiles, ); - const fileService = new FileDiscoveryService(process.cwd()); + const fileService = new FileDiscoveryService(cwd); const fileFiltering = { ...DEFAULT_MEMORY_FILE_FILTERING_OPTIONS, @@ -356,7 +357,7 @@ export async function loadCliConfig( // Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version const { memoryContent, fileCount } = await loadHierarchicalGeminiMemory( - process.cwd(), + cwd, settings.loadMemoryFromIncludeDirectories ? includeDirectories : [], debugMode, fileService, @@ -398,7 +399,7 @@ export async function loadCliConfig( !!argv.promptInteractive || (process.stdin.isTTY && question.length === 0); // In non-interactive mode, exclude tools that require a prompt. const extraExcludes: string[] = []; - if (!interactive) { + if (!interactive && !argv.experimentalAcp) { switch (approvalMode) { case ApprovalMode.DEFAULT: // In default non-interactive mode, all tools that require approval are excluded. @@ -457,7 +458,7 @@ export async function loadCliConfig( sessionId, embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL, sandbox: sandboxConfig, - targetDir: process.cwd(), + targetDir: cwd, includeDirectories, loadMemoryFromIncludeDirectories: settings.loadMemoryFromIncludeDirectories || false, @@ -505,13 +506,13 @@ export async function loadCliConfig( process.env.https_proxy || process.env.HTTP_PROXY || process.env.http_proxy, - cwd: process.cwd(), + cwd, fileDiscoveryService: fileService, bugCommand: settings.bugCommand, model: argv.model || settings.model || DEFAULT_GEMINI_MODEL, extensionContextFilePaths, maxSessionTurns: settings.maxSessionTurns ?? -1, - experimentalAcp: argv.experimentalAcp || false, + experimentalZedIntegration: argv.experimentalAcp || false, listExtensions: argv.listExtensions || false, extensions: allExtensions, blockedMcpServers, diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index acc9c4b2..68f948da 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -106,7 +106,7 @@ async function relaunchWithAdditionalArgs(additionalArgs: string[]) { await new Promise((resolve) => child.on('close', resolve)); process.exit(0); } -import { runAcpPeer } from './acp/acpPeer.js'; +import { runZedIntegration } from './zed-integration/zedIntegration.js'; export function setupUnhandledRejectionHandler() { let unhandledRejectionOccurred = false; @@ -250,8 +250,8 @@ export async function main() { await getOauthClient(settings.merged.selectedAuthType, config); } - if (config.getExperimentalAcp()) { - return runAcpPeer(config, settings); + if (config.getExperimentalZedIntegration()) { + return runZedIntegration(config, settings, extensions, argv); } let input = config.getQuestion(); diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index ee5251d3..64b064e2 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -23,7 +23,7 @@ import { ToolCall, // Import from core Status as ToolCallStatusType, ApprovalMode, - Icon, + Kind, BaseTool, AnyDeclarativeTool, AnyToolInvocation, @@ -67,7 +67,7 @@ class MockTool extends BaseTool<object, ToolResult> { name, displayName, 'A mock tool for testing', - Icon.Hammer, + Kind.Other, {}, isOutputMarkdown, canUpdateOutput, diff --git a/packages/cli/src/zed-integration/acp.ts b/packages/cli/src/zed-integration/acp.ts new file mode 100644 index 00000000..eef4e1ee --- /dev/null +++ b/packages/cli/src/zed-integration/acp.ts @@ -0,0 +1,366 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */ + +import { z } from 'zod'; +import * as schema from './schema.js'; +export * from './schema.js'; + +import { WritableStream, ReadableStream } from 'node:stream/web'; + +export class AgentSideConnection implements Client { + #connection: Connection; + + constructor( + toAgent: (conn: Client) => Agent, + input: WritableStream<Uint8Array>, + output: ReadableStream<Uint8Array>, + ) { + const agent = toAgent(this); + + const handler = async ( + method: string, + params: unknown, + ): Promise<unknown> => { + switch (method) { + case schema.AGENT_METHODS.initialize: { + const validatedParams = schema.initializeRequestSchema.parse(params); + return agent.initialize(validatedParams); + } + case schema.AGENT_METHODS.session_new: { + const validatedParams = schema.newSessionRequestSchema.parse(params); + return agent.newSession(validatedParams); + } + case schema.AGENT_METHODS.session_load: { + if (!agent.loadSession) { + throw RequestError.methodNotFound(); + } + const validatedParams = schema.loadSessionRequestSchema.parse(params); + return agent.loadSession(validatedParams); + } + case schema.AGENT_METHODS.authenticate: { + const validatedParams = + schema.authenticateRequestSchema.parse(params); + return agent.authenticate(validatedParams); + } + case schema.AGENT_METHODS.session_prompt: { + const validatedParams = schema.promptRequestSchema.parse(params); + return agent.prompt(validatedParams); + } + case schema.AGENT_METHODS.session_cancel: { + const validatedParams = schema.cancelNotificationSchema.parse(params); + return agent.cancel(validatedParams); + } + default: + throw RequestError.methodNotFound(method); + } + }; + + this.#connection = new Connection(handler, input, output); + } + + /** + * Streams new content to the client including text, tool calls, etc. + */ + async sessionUpdate(params: schema.SessionNotification): Promise<void> { + return await this.#connection.sendNotification( + schema.CLIENT_METHODS.session_update, + params, + ); + } + + /** + * Request permission before running a tool + * + * The agent specifies a series of permission options with different granularity, + * and the client returns the chosen one. + */ + async requestPermission( + params: schema.RequestPermissionRequest, + ): Promise<schema.RequestPermissionResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.session_request_permission, + params, + ); + } + + async readTextFile( + params: schema.ReadTextFileRequest, + ): Promise<schema.ReadTextFileResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.fs_read_text_file, + params, + ); + } + + async writeTextFile( + params: schema.WriteTextFileRequest, + ): Promise<schema.WriteTextFileResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.fs_write_text_file, + params, + ); + } +} + +type AnyMessage = AnyRequest | AnyResponse | AnyNotification; + +type AnyRequest = { + jsonrpc: '2.0'; + id: string | number; + method: string; + params?: unknown; +}; + +type AnyResponse = { + jsonrpc: '2.0'; + id: string | number; +} & Result<unknown>; + +type AnyNotification = { + jsonrpc: '2.0'; + method: string; + params?: unknown; +}; + +type Result<T> = + | { + result: T; + } + | { + error: ErrorResponse; + }; + +type ErrorResponse = { + code: number; + message: string; + data?: unknown; +}; + +type PendingResponse = { + resolve: (response: unknown) => void; + reject: (error: ErrorResponse) => void; +}; + +type MethodHandler = (method: string, params: unknown) => Promise<unknown>; + +class Connection { + #pendingResponses: Map<string | number, PendingResponse> = new Map(); + #nextRequestId: number = 0; + #handler: MethodHandler; + #peerInput: WritableStream<Uint8Array>; + #writeQueue: Promise<void> = Promise.resolve(); + #textEncoder: TextEncoder; + + constructor( + handler: MethodHandler, + peerInput: WritableStream<Uint8Array>, + peerOutput: ReadableStream<Uint8Array>, + ) { + this.#handler = handler; + this.#peerInput = peerInput; + this.#textEncoder = new TextEncoder(); + this.#receive(peerOutput); + } + + async #receive(output: ReadableStream<Uint8Array>) { + let content = ''; + const decoder = new TextDecoder(); + for await (const chunk of output) { + content += decoder.decode(chunk, { stream: true }); + const lines = content.split('\n'); + content = lines.pop() || ''; + + for (const line of lines) { + const trimmedLine = line.trim(); + + if (trimmedLine) { + const message = JSON.parse(trimmedLine); + this.#processMessage(message); + } + } + } + } + + async #processMessage(message: AnyMessage) { + if ('method' in message && 'id' in message) { + // It's a request + const response = await this.#tryCallHandler( + message.method, + message.params, + ); + + await this.#sendMessage({ + jsonrpc: '2.0', + id: message.id, + ...response, + }); + } else if ('method' in message) { + // It's a notification + await this.#tryCallHandler(message.method, message.params); + } else if ('id' in message) { + // It's a response + this.#handleResponse(message as AnyResponse); + } + } + + async #tryCallHandler( + method: string, + params?: unknown, + ): Promise<Result<unknown>> { + try { + const result = await this.#handler(method, params); + return { result: result ?? null }; + } catch (error: unknown) { + if (error instanceof RequestError) { + return error.toResult(); + } + + if (error instanceof z.ZodError) { + return RequestError.invalidParams( + JSON.stringify(error.format(), undefined, 2), + ).toResult(); + } + + let details; + + if (error instanceof Error) { + details = error.message; + } else if ( + typeof error === 'object' && + error != null && + 'message' in error && + typeof error.message === 'string' + ) { + details = error.message; + } + + return RequestError.internalError(details).toResult(); + } + } + + #handleResponse(response: AnyResponse) { + const pendingResponse = this.#pendingResponses.get(response.id); + if (pendingResponse) { + if ('result' in response) { + pendingResponse.resolve(response.result); + } else if ('error' in response) { + pendingResponse.reject(response.error); + } + this.#pendingResponses.delete(response.id); + } + } + + async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> { + const id = this.#nextRequestId++; + const responsePromise = new Promise((resolve, reject) => { + this.#pendingResponses.set(id, { resolve, reject }); + }); + await this.#sendMessage({ jsonrpc: '2.0', id, method, params }); + return responsePromise as Promise<Resp>; + } + + async sendNotification<N>(method: string, params?: N): Promise<void> { + await this.#sendMessage({ jsonrpc: '2.0', method, params }); + } + + async #sendMessage(json: AnyMessage) { + const content = JSON.stringify(json) + '\n'; + this.#writeQueue = this.#writeQueue + .then(async () => { + const writer = this.#peerInput.getWriter(); + try { + await writer.write(this.#textEncoder.encode(content)); + } finally { + writer.releaseLock(); + } + }) + .catch((error) => { + // Continue processing writes on error + console.error('ACP write error:', error); + }); + return this.#writeQueue; + } +} + +export class RequestError extends Error { + data?: { details?: string }; + + constructor( + public code: number, + message: string, + details?: string, + ) { + super(message); + this.name = 'RequestError'; + if (details) { + this.data = { details }; + } + } + + static parseError(details?: string): RequestError { + return new RequestError(-32700, 'Parse error', details); + } + + static invalidRequest(details?: string): RequestError { + return new RequestError(-32600, 'Invalid request', details); + } + + static methodNotFound(details?: string): RequestError { + return new RequestError(-32601, 'Method not found', details); + } + + static invalidParams(details?: string): RequestError { + return new RequestError(-32602, 'Invalid params', details); + } + + static internalError(details?: string): RequestError { + return new RequestError(-32603, 'Internal error', details); + } + + static authRequired(details?: string): RequestError { + return new RequestError(-32000, 'Authentication required', details); + } + + toResult<T>(): Result<T> { + return { + error: { + code: this.code, + message: this.message, + data: this.data, + }, + }; + } +} + +export interface Client { + requestPermission( + params: schema.RequestPermissionRequest, + ): Promise<schema.RequestPermissionResponse>; + sessionUpdate(params: schema.SessionNotification): Promise<void>; + writeTextFile( + params: schema.WriteTextFileRequest, + ): Promise<schema.WriteTextFileResponse>; + readTextFile( + params: schema.ReadTextFileRequest, + ): Promise<schema.ReadTextFileResponse>; +} + +export interface Agent { + initialize( + params: schema.InitializeRequest, + ): Promise<schema.InitializeResponse>; + newSession( + params: schema.NewSessionRequest, + ): Promise<schema.NewSessionResponse>; + loadSession?( + params: schema.LoadSessionRequest, + ): Promise<schema.LoadSessionResponse>; + authenticate(params: schema.AuthenticateRequest): Promise<void>; + prompt(params: schema.PromptRequest): Promise<schema.PromptResponse>; + cancel(params: schema.CancelNotification): Promise<void>; +} diff --git a/packages/cli/src/zed-integration/schema.ts b/packages/cli/src/zed-integration/schema.ts new file mode 100644 index 00000000..4c962131 --- /dev/null +++ b/packages/cli/src/zed-integration/schema.ts @@ -0,0 +1,457 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; + +export const AGENT_METHODS = { + authenticate: 'authenticate', + initialize: 'initialize', + session_cancel: 'session/cancel', + session_load: 'session/load', + session_new: 'session/new', + session_prompt: 'session/prompt', +}; + +export const CLIENT_METHODS = { + fs_read_text_file: 'fs/read_text_file', + fs_write_text_file: 'fs/write_text_file', + session_request_permission: 'session/request_permission', + session_update: 'session/update', +}; + +export const PROTOCOL_VERSION = 1; + +export type WriteTextFileRequest = z.infer<typeof writeTextFileRequestSchema>; + +export type ReadTextFileRequest = z.infer<typeof readTextFileRequestSchema>; + +export type PermissionOptionKind = z.infer<typeof permissionOptionKindSchema>; + +export type Role = z.infer<typeof roleSchema>; + +export type TextResourceContents = z.infer<typeof textResourceContentsSchema>; + +export type BlobResourceContents = z.infer<typeof blobResourceContentsSchema>; + +export type ToolKind = z.infer<typeof toolKindSchema>; + +export type ToolCallStatus = z.infer<typeof toolCallStatusSchema>; + +export type WriteTextFileResponse = z.infer<typeof writeTextFileResponseSchema>; + +export type ReadTextFileResponse = z.infer<typeof readTextFileResponseSchema>; + +export type RequestPermissionOutcome = z.infer< + typeof requestPermissionOutcomeSchema +>; + +export type CancelNotification = z.infer<typeof cancelNotificationSchema>; + +export type AuthenticateRequest = z.infer<typeof authenticateRequestSchema>; + +export type AuthenticateResponse = z.infer<typeof authenticateResponseSchema>; + +export type NewSessionResponse = z.infer<typeof newSessionResponseSchema>; + +export type LoadSessionResponse = z.infer<typeof loadSessionResponseSchema>; + +export type StopReason = z.infer<typeof stopReasonSchema>; + +export type PromptResponse = z.infer<typeof promptResponseSchema>; + +export type ToolCallLocation = z.infer<typeof toolCallLocationSchema>; + +export type PlanEntry = z.infer<typeof planEntrySchema>; + +export type PermissionOption = z.infer<typeof permissionOptionSchema>; + +export type Annotations = z.infer<typeof annotationsSchema>; + +export type RequestPermissionResponse = z.infer< + typeof requestPermissionResponseSchema +>; + +export type FileSystemCapability = z.infer<typeof fileSystemCapabilitySchema>; + +export type EnvVariable = z.infer<typeof envVariableSchema>; + +export type McpServer = z.infer<typeof mcpServerSchema>; + +export type AgentCapabilities = z.infer<typeof agentCapabilitiesSchema>; + +export type AuthMethod = z.infer<typeof authMethodSchema>; + +export type ClientResponse = z.infer<typeof clientResponseSchema>; + +export type ClientNotification = z.infer<typeof clientNotificationSchema>; + +export type EmbeddedResourceResource = z.infer< + typeof embeddedResourceResourceSchema +>; + +export type NewSessionRequest = z.infer<typeof newSessionRequestSchema>; + +export type LoadSessionRequest = z.infer<typeof loadSessionRequestSchema>; + +export type InitializeResponse = z.infer<typeof initializeResponseSchema>; + +export type ContentBlock = z.infer<typeof contentBlockSchema>; + +export type ToolCallContent = z.infer<typeof toolCallContentSchema>; + +export type ToolCall = z.infer<typeof toolCallSchema>; + +export type ClientCapabilities = z.infer<typeof clientCapabilitiesSchema>; + +export type PromptRequest = z.infer<typeof promptRequestSchema>; + +export type SessionUpdate = z.infer<typeof sessionUpdateSchema>; + +export type AgentResponse = z.infer<typeof agentResponseSchema>; + +export type RequestPermissionRequest = z.infer< + typeof requestPermissionRequestSchema +>; + +export type InitializeRequest = z.infer<typeof initializeRequestSchema>; + +export type SessionNotification = z.infer<typeof sessionNotificationSchema>; + +export type ClientRequest = z.infer<typeof clientRequestSchema>; + +export type AgentRequest = z.infer<typeof agentRequestSchema>; + +export type AgentNotification = z.infer<typeof agentNotificationSchema>; + +export const writeTextFileRequestSchema = z.object({ + content: z.string(), + path: z.string(), + sessionId: z.string(), +}); + +export const readTextFileRequestSchema = z.object({ + limit: z.number().optional().nullable(), + line: z.number().optional().nullable(), + path: z.string(), + sessionId: z.string(), +}); + +export const permissionOptionKindSchema = z.union([ + z.literal('allow_once'), + z.literal('allow_always'), + z.literal('reject_once'), + z.literal('reject_always'), +]); + +export const roleSchema = z.union([z.literal('assistant'), z.literal('user')]); + +export const textResourceContentsSchema = z.object({ + mimeType: z.string().optional().nullable(), + text: z.string(), + uri: z.string(), +}); + +export const blobResourceContentsSchema = z.object({ + blob: z.string(), + mimeType: z.string().optional().nullable(), + uri: z.string(), +}); + +export const toolKindSchema = z.union([ + z.literal('read'), + z.literal('edit'), + z.literal('delete'), + z.literal('move'), + z.literal('search'), + z.literal('execute'), + z.literal('think'), + z.literal('fetch'), + z.literal('other'), +]); + +export const toolCallStatusSchema = z.union([ + z.literal('pending'), + z.literal('in_progress'), + z.literal('completed'), + z.literal('failed'), +]); + +export const writeTextFileResponseSchema = z.null(); + +export const readTextFileResponseSchema = z.object({ + content: z.string(), +}); + +export const requestPermissionOutcomeSchema = z.union([ + z.object({ + outcome: z.literal('cancelled'), + }), + z.object({ + optionId: z.string(), + outcome: z.literal('selected'), + }), +]); + +export const cancelNotificationSchema = z.object({ + sessionId: z.string(), +}); + +export const authenticateRequestSchema = z.object({ + methodId: z.string(), +}); + +export const authenticateResponseSchema = z.null(); + +export const newSessionResponseSchema = z.object({ + sessionId: z.string(), +}); + +export const loadSessionResponseSchema = z.null(); + +export const stopReasonSchema = z.union([ + z.literal('end_turn'), + z.literal('max_tokens'), + z.literal('refusal'), + z.literal('cancelled'), +]); + +export const promptResponseSchema = z.object({ + stopReason: stopReasonSchema, +}); + +export const toolCallLocationSchema = z.object({ + line: z.number().optional().nullable(), + path: z.string(), +}); + +export const planEntrySchema = z.object({ + content: z.string(), + priority: z.union([z.literal('high'), z.literal('medium'), z.literal('low')]), + status: z.union([ + z.literal('pending'), + z.literal('in_progress'), + z.literal('completed'), + ]), +}); + +export const permissionOptionSchema = z.object({ + kind: permissionOptionKindSchema, + name: z.string(), + optionId: z.string(), +}); + +export const annotationsSchema = z.object({ + audience: z.array(roleSchema).optional().nullable(), + lastModified: z.string().optional().nullable(), + priority: z.number().optional().nullable(), +}); + +export const requestPermissionResponseSchema = z.object({ + outcome: requestPermissionOutcomeSchema, +}); + +export const fileSystemCapabilitySchema = z.object({ + readTextFile: z.boolean(), + writeTextFile: z.boolean(), +}); + +export const envVariableSchema = z.object({ + name: z.string(), + value: z.string(), +}); + +export const mcpServerSchema = z.object({ + args: z.array(z.string()), + command: z.string(), + env: z.array(envVariableSchema), + name: z.string(), +}); + +export const agentCapabilitiesSchema = z.object({ + loadSession: z.boolean(), +}); + +export const authMethodSchema = z.object({ + description: z.string().nullable(), + id: z.string(), + name: z.string(), +}); + +export const clientResponseSchema = z.union([ + writeTextFileResponseSchema, + readTextFileResponseSchema, + requestPermissionResponseSchema, +]); + +export const clientNotificationSchema = cancelNotificationSchema; + +export const embeddedResourceResourceSchema = z.union([ + textResourceContentsSchema, + blobResourceContentsSchema, +]); + +export const newSessionRequestSchema = z.object({ + cwd: z.string(), + mcpServers: z.array(mcpServerSchema), +}); + +export const loadSessionRequestSchema = z.object({ + cwd: z.string(), + mcpServers: z.array(mcpServerSchema), + sessionId: z.string(), +}); + +export const initializeResponseSchema = z.object({ + agentCapabilities: agentCapabilitiesSchema, + authMethods: z.array(authMethodSchema), + protocolVersion: z.number(), +}); + +export const contentBlockSchema = z.union([ + z.object({ + annotations: annotationsSchema.optional().nullable(), + text: z.string(), + type: z.literal('text'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + data: z.string(), + mimeType: z.string(), + type: z.literal('image'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + data: z.string(), + mimeType: z.string(), + type: z.literal('audio'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + description: z.string().optional().nullable(), + mimeType: z.string().optional().nullable(), + name: z.string(), + size: z.number().optional().nullable(), + title: z.string().optional().nullable(), + type: z.literal('resource_link'), + uri: z.string(), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + resource: embeddedResourceResourceSchema, + type: z.literal('resource'), + }), +]); + +export const toolCallContentSchema = z.union([ + z.object({ + content: contentBlockSchema, + type: z.literal('content'), + }), + z.object({ + newText: z.string(), + oldText: z.string().nullable(), + path: z.string(), + type: z.literal('diff'), + }), +]); + +export const toolCallSchema = z.object({ + content: z.array(toolCallContentSchema).optional(), + kind: toolKindSchema, + locations: z.array(toolCallLocationSchema).optional(), + rawInput: z.unknown().optional(), + status: toolCallStatusSchema, + title: z.string(), + toolCallId: z.string(), +}); + +export const clientCapabilitiesSchema = z.object({ + fs: fileSystemCapabilitySchema, +}); + +export const promptRequestSchema = z.object({ + prompt: z.array(contentBlockSchema), + sessionId: z.string(), +}); + +export const sessionUpdateSchema = z.union([ + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('user_message_chunk'), + }), + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('agent_message_chunk'), + }), + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('agent_thought_chunk'), + }), + z.object({ + content: z.array(toolCallContentSchema).optional(), + kind: toolKindSchema, + locations: z.array(toolCallLocationSchema).optional(), + rawInput: z.unknown().optional(), + sessionUpdate: z.literal('tool_call'), + status: toolCallStatusSchema, + title: z.string(), + toolCallId: z.string(), + }), + z.object({ + content: z.array(toolCallContentSchema).optional().nullable(), + kind: toolKindSchema.optional().nullable(), + locations: z.array(toolCallLocationSchema).optional().nullable(), + rawInput: z.unknown().optional(), + sessionUpdate: z.literal('tool_call_update'), + status: toolCallStatusSchema.optional().nullable(), + title: z.string().optional().nullable(), + toolCallId: z.string(), + }), + z.object({ + entries: z.array(planEntrySchema), + sessionUpdate: z.literal('plan'), + }), +]); + +export const agentResponseSchema = z.union([ + initializeResponseSchema, + authenticateResponseSchema, + newSessionResponseSchema, + loadSessionResponseSchema, + promptResponseSchema, +]); + +export const requestPermissionRequestSchema = z.object({ + options: z.array(permissionOptionSchema), + sessionId: z.string(), + toolCall: toolCallSchema, +}); + +export const initializeRequestSchema = z.object({ + clientCapabilities: clientCapabilitiesSchema, + protocolVersion: z.number(), +}); + +export const sessionNotificationSchema = z.object({ + sessionId: z.string(), + update: sessionUpdateSchema, +}); + +export const clientRequestSchema = z.union([ + writeTextFileRequestSchema, + readTextFileRequestSchema, + requestPermissionRequestSchema, +]); + +export const agentRequestSchema = z.union([ + initializeRequestSchema, + authenticateRequestSchema, + newSessionRequestSchema, + loadSessionRequestSchema, + promptRequestSchema, +]); + +export const agentNotificationSchema = sessionNotificationSchema; diff --git a/packages/cli/src/acp/acpPeer.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 40d8753f..1b5baa8a 100644 --- a/packages/cli/src/acp/acpPeer.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -21,16 +21,26 @@ import { getErrorMessage, isWithinRoot, getErrorStatus, + MCPServerConfig, } from '@google/gemini-cli-core'; import * as acp from './acp.js'; -import { Agent } from './acp.js'; import { Readable, Writable } from 'node:stream'; import { Content, Part, FunctionCall, PartListUnion } from '@google/genai'; import { LoadedSettings, SettingScope } from '../config/settings.js'; import * as fs from 'fs/promises'; import * as path from 'path'; +import { z } from 'zod'; -export async function runAcpPeer(config: Config, settings: LoadedSettings) { +import { randomUUID } from 'crypto'; +import { Extension } from '../config/extension.js'; +import { CliArgs, loadCliConfig } from '../config/config.js'; + +export async function runZedIntegration( + config: Config, + settings: LoadedSettings, + extensions: Extension[], + argv: CliArgs, +) { const stdout = Writable.toWeb(process.stdout) as WritableStream; const stdin = Readable.toWeb(process.stdin) as ReadableStream<Uint8Array>; @@ -40,76 +50,176 @@ export async function runAcpPeer(config: Config, settings: LoadedSettings) { console.info = console.error; console.debug = console.error; - new acp.ClientConnection( - (client: acp.Client) => new GeminiAgent(config, settings, client), + new acp.AgentSideConnection( + (client: acp.Client) => + new GeminiAgent(config, settings, extensions, argv, client), stdout, stdin, ); } -class GeminiAgent implements Agent { - chat?: GeminiChat; - pendingSend?: AbortController; +class GeminiAgent { + private sessions: Map<string, Session> = new Map(); constructor( private config: Config, private settings: LoadedSettings, + private extensions: Extension[], + private argv: CliArgs, private client: acp.Client, ) {} - async initialize(_: acp.InitializeParams): Promise<acp.InitializeResponse> { + async initialize( + _args: acp.InitializeRequest, + ): Promise<acp.InitializeResponse> { + const authMethods = [ + { + id: AuthType.LOGIN_WITH_GOOGLE, + name: 'Log in with Google', + description: null, + }, + { + id: AuthType.USE_GEMINI, + name: 'Use Gemini API key', + description: + 'Requires setting the `GEMINI_API_KEY` environment variable', + }, + { + id: AuthType.USE_VERTEX_AI, + name: 'Vertex AI', + description: null, + }, + ]; + + return { + protocolVersion: acp.PROTOCOL_VERSION, + authMethods, + agentCapabilities: { + loadSession: false, + }, + }; + } + + async authenticate({ methodId }: acp.AuthenticateRequest): Promise<void> { + const method = z.nativeEnum(AuthType).parse(methodId); + + await clearCachedCredentialFile(); + await this.config.refreshAuth(method); + this.settings.setValue(SettingScope.User, 'selectedAuthType', method); + } + + async newSession({ + cwd, + mcpServers, + }: acp.NewSessionRequest): Promise<acp.NewSessionResponse> { + const sessionId = randomUUID(); + const config = await this.newSessionConfig(sessionId, cwd, mcpServers); + let isAuthenticated = false; if (this.settings.merged.selectedAuthType) { try { - await this.config.refreshAuth(this.settings.merged.selectedAuthType); + await config.refreshAuth(this.settings.merged.selectedAuthType); isAuthenticated = true; - } catch (error) { - console.error('Failed to refresh auth:', error); + } catch (e) { + console.error(`Authentication failed: ${e}`); } } - return { protocolVersion: acp.LATEST_PROTOCOL_VERSION, isAuthenticated }; + + if (!isAuthenticated) { + throw acp.RequestError.authRequired(); + } + + const geminiClient = config.getGeminiClient(); + const chat = await geminiClient.startChat(); + const session = new Session(sessionId, chat, config, this.client); + this.sessions.set(sessionId, session); + + return { + sessionId, + }; } - async authenticate(): Promise<void> { - await clearCachedCredentialFile(); - await this.config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); - this.settings.setValue( - SettingScope.User, - 'selectedAuthType', - AuthType.LOGIN_WITH_GOOGLE, + async newSessionConfig( + sessionId: string, + cwd: string, + mcpServers: acp.McpServer[], + ): Promise<Config> { + const mergedMcpServers = { ...this.settings.merged.mcpServers }; + + for (const { command, args, env: rawEnv, name } of mcpServers) { + const env: Record<string, string> = {}; + for (const { name: envName, value } of rawEnv) { + env[envName] = value; + } + mergedMcpServers[name] = new MCPServerConfig(command, args, env, cwd); + } + + const settings = { ...this.settings.merged, mcpServers: mergedMcpServers }; + + const config = await loadCliConfig( + settings, + this.extensions, + sessionId, + this.argv, + cwd, ); + + await config.initialize(); + return config; } - async cancelSendMessage(): Promise<void> { - if (!this.pendingSend) { - throw new Error('Not currently generating'); + async cancel(params: acp.CancelNotification): Promise<void> { + const session = this.sessions.get(params.sessionId); + if (!session) { + throw new Error(`Session not found: ${params.sessionId}`); } + await session.cancelPendingPrompt(); + } - this.pendingSend.abort(); - delete this.pendingSend; + async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> { + const session = this.sessions.get(params.sessionId); + if (!session) { + throw new Error(`Session not found: ${params.sessionId}`); + } + return session.prompt(params); } +} - async sendUserMessage(params: acp.SendUserMessageParams): Promise<void> { - this.pendingSend?.abort(); - const pendingSend = new AbortController(); - this.pendingSend = pendingSend; +class Session { + private pendingPrompt: AbortController | null = null; + + constructor( + private readonly id: string, + private readonly chat: GeminiChat, + private readonly config: Config, + private readonly client: acp.Client, + ) {} - if (!this.chat) { - const geminiClient = this.config.getGeminiClient(); - this.chat = await geminiClient.startChat(); + async cancelPendingPrompt(): Promise<void> { + if (!this.pendingPrompt) { + throw new Error('Not currently generating'); } + this.pendingPrompt.abort(); + this.pendingPrompt = null; + } + + async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> { + this.pendingPrompt?.abort(); + const pendingSend = new AbortController(); + this.pendingPrompt = pendingSend; + const promptId = Math.random().toString(16).slice(2); - const chat = this.chat!; - const toolRegistry: ToolRegistry = await this.config.getToolRegistry(); - const parts = await this.#resolveUserMessage(params, pendingSend.signal); + const chat = this.chat; + + const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal); let nextMessage: Content | null = { role: 'user', parts }; while (nextMessage !== null) { if (pendingSend.signal.aborted) { chat.addHistory(nextMessage); - return; + return { stopReason: 'cancelled' }; } const functionCalls: FunctionCall[] = []; @@ -120,11 +230,6 @@ class GeminiAgent implements Agent { message: nextMessage?.parts ?? [], config: { abortSignal: pendingSend.signal, - tools: [ - { - functionDeclarations: toolRegistry.getFunctionDeclarations(), - }, - ], }, }, promptId, @@ -133,7 +238,7 @@ class GeminiAgent implements Agent { for await (const resp of responseStream) { if (pendingSend.signal.aborted) { - return; + return { stopReason: 'cancelled' }; } if (resp.candidates && resp.candidates.length > 0) { @@ -143,10 +248,16 @@ class GeminiAgent implements Agent { continue; } - this.client.streamAssistantMessageChunk({ - chunk: part.thought - ? { thought: part.text } - : { text: part.text }, + const content: acp.ContentBlock = { + type: 'text', + text: part.text, + }; + + this.sendUpdate({ + sessionUpdate: part.thought + ? 'agent_thought_chunk' + : 'agent_message_chunk', + content, }); } } @@ -170,11 +281,7 @@ class GeminiAgent implements Agent { const toolResponseParts: Part[] = []; for (const fc of functionCalls) { - const response = await this.#runTool( - pendingSend.signal, - promptId, - fc, - ); + const response = await this.runTool(pendingSend.signal, promptId, fc); const parts = Array.isArray(response) ? response : [response]; @@ -190,9 +297,20 @@ class GeminiAgent implements Agent { nextMessage = { role: 'user', parts: toolResponseParts }; } } + + return { stopReason: 'end_turn' }; + } + + private async sendUpdate(update: acp.SessionUpdate): Promise<void> { + const params: acp.SessionNotification = { + sessionId: this.id, + update, + }; + + await this.client.sessionUpdate(params); } - async #runTool( + private async runTool( abortSignal: AbortSignal, promptId: string, fc: FunctionCall, @@ -239,68 +357,82 @@ class GeminiAgent implements Agent { ); } - let toolCallId: number | undefined = undefined; - try { - const invocation = tool.build(args); - const confirmationDetails = - await invocation.shouldConfirmExecute(abortSignal); - if (confirmationDetails) { - let content: acp.ToolCallContent | null = null; - if (confirmationDetails.type === 'edit') { - content = { - type: 'diff', - path: confirmationDetails.fileName, - oldText: confirmationDetails.originalContent, - newText: confirmationDetails.newContent, - }; - } + const invocation = tool.build(args); + const confirmationDetails = + await invocation.shouldConfirmExecute(abortSignal); + + if (confirmationDetails) { + const content: acp.ToolCallContent[] = []; - const result = await this.client.requestToolCallConfirmation({ - label: invocation.getDescription(), - icon: tool.icon, + if (confirmationDetails.type === 'edit') { + content.push({ + type: 'diff', + path: confirmationDetails.fileName, + oldText: confirmationDetails.originalContent, + newText: confirmationDetails.newContent, + }); + } + + const params: acp.RequestPermissionRequest = { + sessionId: this.id, + options: toPermissionOptions(confirmationDetails), + toolCall: { + toolCallId: callId, + status: 'pending', + title: invocation.getDescription(), content, - confirmation: toAcpToolCallConfirmation(confirmationDetails), locations: invocation.toolLocations(), - }); + kind: tool.kind, + }, + }; - await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome)); - switch (result.outcome) { - case 'reject': - return errorResponse( - new Error(`Tool "${fc.name}" not allowed to run by the user.`), - ); + const output = await this.client.requestPermission(params); + const outcome = + output.outcome.outcome === 'cancelled' + ? ToolConfirmationOutcome.Cancel + : z + .nativeEnum(ToolConfirmationOutcome) + .parse(output.outcome.optionId); - case 'cancel': - return errorResponse( - new Error(`Tool "${fc.name}" was canceled by the user.`), - ); - case 'allow': - case 'alwaysAllow': - case 'alwaysAllowMcpServer': - case 'alwaysAllowTool': - break; - default: { - const resultOutcome: never = result.outcome; - throw new Error(`Unexpected: ${resultOutcome}`); - } + await confirmationDetails.onConfirm(outcome); + + switch (outcome) { + case ToolConfirmationOutcome.Cancel: + return errorResponse( + new Error(`Tool "${fc.name}" was canceled by the user.`), + ); + case ToolConfirmationOutcome.ProceedOnce: + case ToolConfirmationOutcome.ProceedAlways: + case ToolConfirmationOutcome.ProceedAlwaysServer: + case ToolConfirmationOutcome.ProceedAlwaysTool: + case ToolConfirmationOutcome.ModifyWithEditor: + break; + default: { + const resultOutcome: never = outcome; + throw new Error(`Unexpected: ${resultOutcome}`); } - toolCallId = result.id; - } else { - const result = await this.client.pushToolCall({ - icon: tool.icon, - label: invocation.getDescription(), - locations: invocation.toolLocations(), - }); - toolCallId = result.id; } + } else { + await this.sendUpdate({ + sessionUpdate: 'tool_call', + toolCallId: callId, + status: 'in_progress', + title: invocation.getDescription(), + content: [], + locations: invocation.toolLocations(), + kind: tool.kind, + }); + } + try { const toolResult: ToolResult = await invocation.execute(abortSignal); - const toolCallContent = toToolCallContent(toolResult); + const content = toToolCallContent(toolResult); - await this.client.updateToolCall({ - toolCallId, - status: 'finished', - content: toolCallContent, + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'completed', + content: content ? [content] : [], }); const durationMs = Date.now() - startTime; @@ -317,31 +449,55 @@ class GeminiAgent implements Agent { return convertToFunctionResponse(fc.name, callId, toolResult.llmContent); } catch (e) { const error = e instanceof Error ? e : new Error(String(e)); - if (toolCallId) { - await this.client.updateToolCall({ - toolCallId, - status: 'error', - content: { type: 'markdown', markdown: error.message }, - }); - } + + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'failed', + content: [ + { type: 'content', content: { type: 'text', text: error.message } }, + ], + }); + return errorResponse(error); } } - async #resolveUserMessage( - message: acp.SendUserMessageParams, + async #resolvePrompt( + message: acp.ContentBlock[], abortSignal: AbortSignal, ): Promise<Part[]> { - const atPathCommandParts = message.chunks.filter((part) => 'path' in part); + const parts = message.map((part) => { + switch (part.type) { + case 'text': + return { text: part.text }; + case 'resource_link': + return { + fileData: { + mimeData: part.mimeType, + name: part.name, + fileUri: part.uri, + }, + }; + case 'resource': { + return { + fileData: { + mimeData: part.resource.mimeType, + name: part.resource.uri, + fileUri: part.resource.uri, + }, + }; + } + default: { + throw new Error(`Unexpected chunk type: '${part.type}'`); + } + } + }); + + const atPathCommandParts = parts.filter((part) => 'fileData' in part); if (atPathCommandParts.length === 0) { - return message.chunks.map((chunk) => { - if ('text' in chunk) { - return { text: chunk.text }; - } else { - throw new Error('Unexpected chunk type'); - } - }); + return parts; } // Get centralized file discovery service @@ -362,8 +518,7 @@ class GeminiAgent implements Agent { } for (const atPathPart of atPathCommandParts) { - const pathName = atPathPart.path; - + const pathName = atPathPart.fileData!.fileUri; // Check if path should be ignored by git if (fileDiscovery.shouldGitIgnoreFile(pathName)) { ignoredPaths.push(pathName); @@ -373,10 +528,8 @@ class GeminiAgent implements Agent { console.warn(`Path ${pathName} is ${reason}.`); continue; } - let currentPathSpec = pathName; let resolvedSuccessfully = false; - try { const absolutePath = path.resolve(this.config.getTargetDir(), pathName); if (isWithinRoot(absolutePath, this.config.getTargetDir())) { @@ -385,24 +538,22 @@ class GeminiAgent implements Agent { currentPathSpec = pathName.endsWith('/') ? `${pathName}**` : `${pathName}/**`; - this.#debug( + this.debug( `Path ${pathName} resolved to directory, using glob: ${currentPathSpec}`, ); } else { - this.#debug( - `Path ${pathName} resolved to file: ${currentPathSpec}`, - ); + this.debug(`Path ${pathName} resolved to file: ${currentPathSpec}`); } resolvedSuccessfully = true; } else { - this.#debug( + this.debug( `Path ${pathName} is outside the project directory. Skipping.`, ); } } catch (error) { if (isNodeError(error) && error.code === 'ENOENT') { if (this.config.getEnableRecursiveFileSearch() && globTool) { - this.#debug( + this.debug( `Path ${pathName} not found directly, attempting glob search.`, ); try { @@ -426,17 +577,17 @@ class GeminiAgent implements Agent { this.config.getTargetDir(), firstMatchAbsolute, ); - this.#debug( + this.debug( `Glob search for ${pathName} found ${firstMatchAbsolute}, using relative path: ${currentPathSpec}`, ); resolvedSuccessfully = true; } else { - this.#debug( + this.debug( `Glob search for '**/*${pathName}*' did not return a usable path. Path ${pathName} will be skipped.`, ); } } else { - this.#debug( + this.debug( `Glob search for '**/*${pathName}*' found no files or an error. Path ${pathName} will be skipped.`, ); } @@ -446,7 +597,7 @@ class GeminiAgent implements Agent { ); } } else { - this.#debug( + this.debug( `Glob tool not found. Path ${pathName} will be skipped.`, ); } @@ -456,23 +607,22 @@ class GeminiAgent implements Agent { ); } } - if (resolvedSuccessfully) { pathSpecsToRead.push(currentPathSpec); atPathToResolvedSpecMap.set(pathName, currentPathSpec); contentLabelsForDisplay.push(pathName); } } - // Construct the initial part of the query for the LLM let initialQueryText = ''; - for (let i = 0; i < message.chunks.length; i++) { - const chunk = message.chunks[i]; + for (let i = 0; i < parts.length; i++) { + const chunk = parts[i]; if ('text' in chunk) { initialQueryText += chunk.text; } else { // type === 'atPath' - const resolvedSpec = atPathToResolvedSpecMap.get(chunk.path); + const resolvedSpec = + chunk.fileData && atPathToResolvedSpecMap.get(chunk.fileData.fileUri); if ( i > 0 && initialQueryText.length > 0 && @@ -480,10 +630,11 @@ class GeminiAgent implements Agent { resolvedSpec ) { // Add space if previous part was text and didn't end with space, or if previous was @path - const prevPart = message.chunks[i - 1]; + const prevPart = parts[i - 1]; if ( 'text' in prevPart || - ('path' in prevPart && atPathToResolvedSpecMap.has(prevPart.path)) + ('fileData' in prevPart && + atPathToResolvedSpecMap.has(prevPart.fileData!.fileUri)) ) { initialQueryText += ' '; } @@ -497,56 +648,64 @@ class GeminiAgent implements Agent { i > 0 && initialQueryText.length > 0 && !initialQueryText.endsWith(' ') && - !chunk.path.startsWith(' ') + !chunk.fileData?.fileUri.startsWith(' ') ) { initialQueryText += ' '; } - initialQueryText += `@${chunk.path}`; + if (chunk.fileData?.fileUri) { + initialQueryText += `@${chunk.fileData.fileUri}`; + } } } } initialQueryText = initialQueryText.trim(); - // Inform user about ignored paths if (ignoredPaths.length > 0) { const ignoreType = respectGitIgnore ? 'git-ignored' : 'custom-ignored'; - this.#debug( + this.debug( `Ignored ${ignoredPaths.length} ${ignoreType} files: ${ignoredPaths.join(', ')}`, ); } - // Fallback for lone "@" or completely invalid @-commands resulting in empty initialQueryText if (pathSpecsToRead.length === 0) { console.warn('No valid file paths found in @ commands to read.'); return [{ text: initialQueryText }]; } - const processedQueryParts: Part[] = [{ text: initialQueryText }]; - const toolArgs = { paths: pathSpecsToRead, respectGitIgnore, // Use configuration setting }; - let toolCallId: number | undefined = undefined; + const callId = `${readManyFilesTool.name}-${Date.now()}`; + try { const invocation = readManyFilesTool.build(toolArgs); - const toolCall = await this.client.pushToolCall({ - icon: readManyFilesTool.icon, - label: invocation.getDescription(), + + await this.sendUpdate({ + sessionUpdate: 'tool_call', + toolCallId: callId, + status: 'in_progress', + title: invocation.getDescription(), + content: [], + locations: invocation.toolLocations(), + kind: readManyFilesTool.kind, }); - toolCallId = toolCall.id; + const result = await invocation.execute(abortSignal); const content = toToolCallContent(result) || { - type: 'markdown', - markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, + type: 'content', + content: { + type: 'text', + text: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, + }, }; - await this.client.updateToolCall({ - toolCallId: toolCall.id, - status: 'finished', - content, + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'completed', + content: content ? [content] : [], }); - if (Array.isArray(result.llmContent)) { const fileContentRegex = /^--- (.*?) ---\n\n([\s\S]*?)\n\n$/; processedQueryParts.push({ @@ -576,24 +735,28 @@ class GeminiAgent implements Agent { 'read_many_files tool returned no content or empty content.', ); } - return processedQueryParts; } catch (error: unknown) { - if (toolCallId) { - await this.client.updateToolCall({ - toolCallId, - status: 'error', - content: { - type: 'markdown', - markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'failed', + content: [ + { + type: 'content', + content: { + type: 'text', + text: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, + }, }, - }); - } + ], + }); + throw error; } } - #debug(msg: string) { + debug(msg: string) { if (this.config.getDebugMode()) { console.warn(msg); } @@ -604,8 +767,8 @@ function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null { if (toolResult.returnDisplay) { if (typeof toolResult.returnDisplay === 'string') { return { - type: 'markdown', - markdown: toolResult.returnDisplay, + type: 'content', + content: { type: 'text', text: toolResult.returnDisplay }, }; } else { return { @@ -620,57 +783,66 @@ function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null { } } -function toAcpToolCallConfirmation( - confirmationDetails: ToolCallConfirmationDetails, -): acp.ToolCallConfirmation { - switch (confirmationDetails.type) { +const basicPermissionOptions = [ + { + optionId: ToolConfirmationOutcome.ProceedOnce, + name: 'Allow', + kind: 'allow_once', + }, + { + optionId: ToolConfirmationOutcome.Cancel, + name: 'Reject', + kind: 'reject_once', + }, +] as const; + +function toPermissionOptions( + confirmation: ToolCallConfirmationDetails, +): acp.PermissionOption[] { + switch (confirmation.type) { case 'edit': - return { type: 'edit' }; + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: 'Allow All Edits', + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; case 'exec': - return { - type: 'execute', - rootCommand: confirmationDetails.rootCommand, - command: confirmationDetails.command, - }; + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: `Always Allow ${confirmation.rootCommand}`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; case 'mcp': - return { - type: 'mcp', - serverName: confirmationDetails.serverName, - toolName: confirmationDetails.toolName, - toolDisplayName: confirmationDetails.toolDisplayName, - }; + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlwaysServer, + name: `Always Allow ${confirmation.serverName}`, + kind: 'allow_always', + }, + { + optionId: ToolConfirmationOutcome.ProceedAlwaysTool, + name: `Always Allow ${confirmation.toolName}`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; case 'info': - return { - type: 'fetch', - urls: confirmationDetails.urls || [], - description: confirmationDetails.urls?.length - ? null - : confirmationDetails.prompt, - }; - default: { - const unreachable: never = confirmationDetails; - throw new Error(`Unexpected: ${unreachable}`); - } - } -} - -function toToolCallOutcome( - outcome: acp.ToolCallConfirmationOutcome, -): ToolConfirmationOutcome { - switch (outcome) { - case 'allow': - return ToolConfirmationOutcome.ProceedOnce; - case 'alwaysAllow': - return ToolConfirmationOutcome.ProceedAlways; - case 'alwaysAllowMcpServer': - return ToolConfirmationOutcome.ProceedAlwaysServer; - case 'alwaysAllowTool': - return ToolConfirmationOutcome.ProceedAlwaysTool; - case 'reject': - case 'cancel': - return ToolConfirmationOutcome.Cancel; + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: `Always Allow`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; default: { - const unreachable: never = outcome; + const unreachable: never = confirmation; throw new Error(`Unexpected: ${unreachable}`); } } |
