diff options
Diffstat (limited to 'packages/cli/src/zed-integration')
| -rw-r--r-- | packages/cli/src/zed-integration/acp.ts | 366 | ||||
| -rw-r--r-- | packages/cli/src/zed-integration/schema.ts | 457 | ||||
| -rw-r--r-- | packages/cli/src/zed-integration/zedIntegration.ts | 849 |
3 files changed, 1672 insertions, 0 deletions
diff --git a/packages/cli/src/zed-integration/acp.ts b/packages/cli/src/zed-integration/acp.ts new file mode 100644 index 00000000..eef4e1ee --- /dev/null +++ b/packages/cli/src/zed-integration/acp.ts @@ -0,0 +1,366 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */ + +import { z } from 'zod'; +import * as schema from './schema.js'; +export * from './schema.js'; + +import { WritableStream, ReadableStream } from 'node:stream/web'; + +export class AgentSideConnection implements Client { + #connection: Connection; + + constructor( + toAgent: (conn: Client) => Agent, + input: WritableStream<Uint8Array>, + output: ReadableStream<Uint8Array>, + ) { + const agent = toAgent(this); + + const handler = async ( + method: string, + params: unknown, + ): Promise<unknown> => { + switch (method) { + case schema.AGENT_METHODS.initialize: { + const validatedParams = schema.initializeRequestSchema.parse(params); + return agent.initialize(validatedParams); + } + case schema.AGENT_METHODS.session_new: { + const validatedParams = schema.newSessionRequestSchema.parse(params); + return agent.newSession(validatedParams); + } + case schema.AGENT_METHODS.session_load: { + if (!agent.loadSession) { + throw RequestError.methodNotFound(); + } + const validatedParams = schema.loadSessionRequestSchema.parse(params); + return agent.loadSession(validatedParams); + } + case schema.AGENT_METHODS.authenticate: { + const validatedParams = + schema.authenticateRequestSchema.parse(params); + return agent.authenticate(validatedParams); + } + case schema.AGENT_METHODS.session_prompt: { + const validatedParams = schema.promptRequestSchema.parse(params); + return agent.prompt(validatedParams); + } + case schema.AGENT_METHODS.session_cancel: { + const validatedParams = schema.cancelNotificationSchema.parse(params); + return agent.cancel(validatedParams); + } + default: + throw RequestError.methodNotFound(method); + } + }; + + this.#connection = new Connection(handler, input, output); + } + + /** + * Streams new content to the client including text, tool calls, etc. + */ + async sessionUpdate(params: schema.SessionNotification): Promise<void> { + return await this.#connection.sendNotification( + schema.CLIENT_METHODS.session_update, + params, + ); + } + + /** + * Request permission before running a tool + * + * The agent specifies a series of permission options with different granularity, + * and the client returns the chosen one. + */ + async requestPermission( + params: schema.RequestPermissionRequest, + ): Promise<schema.RequestPermissionResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.session_request_permission, + params, + ); + } + + async readTextFile( + params: schema.ReadTextFileRequest, + ): Promise<schema.ReadTextFileResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.fs_read_text_file, + params, + ); + } + + async writeTextFile( + params: schema.WriteTextFileRequest, + ): Promise<schema.WriteTextFileResponse> { + return await this.#connection.sendRequest( + schema.CLIENT_METHODS.fs_write_text_file, + params, + ); + } +} + +type AnyMessage = AnyRequest | AnyResponse | AnyNotification; + +type AnyRequest = { + jsonrpc: '2.0'; + id: string | number; + method: string; + params?: unknown; +}; + +type AnyResponse = { + jsonrpc: '2.0'; + id: string | number; +} & Result<unknown>; + +type AnyNotification = { + jsonrpc: '2.0'; + method: string; + params?: unknown; +}; + +type Result<T> = + | { + result: T; + } + | { + error: ErrorResponse; + }; + +type ErrorResponse = { + code: number; + message: string; + data?: unknown; +}; + +type PendingResponse = { + resolve: (response: unknown) => void; + reject: (error: ErrorResponse) => void; +}; + +type MethodHandler = (method: string, params: unknown) => Promise<unknown>; + +class Connection { + #pendingResponses: Map<string | number, PendingResponse> = new Map(); + #nextRequestId: number = 0; + #handler: MethodHandler; + #peerInput: WritableStream<Uint8Array>; + #writeQueue: Promise<void> = Promise.resolve(); + #textEncoder: TextEncoder; + + constructor( + handler: MethodHandler, + peerInput: WritableStream<Uint8Array>, + peerOutput: ReadableStream<Uint8Array>, + ) { + this.#handler = handler; + this.#peerInput = peerInput; + this.#textEncoder = new TextEncoder(); + this.#receive(peerOutput); + } + + async #receive(output: ReadableStream<Uint8Array>) { + let content = ''; + const decoder = new TextDecoder(); + for await (const chunk of output) { + content += decoder.decode(chunk, { stream: true }); + const lines = content.split('\n'); + content = lines.pop() || ''; + + for (const line of lines) { + const trimmedLine = line.trim(); + + if (trimmedLine) { + const message = JSON.parse(trimmedLine); + this.#processMessage(message); + } + } + } + } + + async #processMessage(message: AnyMessage) { + if ('method' in message && 'id' in message) { + // It's a request + const response = await this.#tryCallHandler( + message.method, + message.params, + ); + + await this.#sendMessage({ + jsonrpc: '2.0', + id: message.id, + ...response, + }); + } else if ('method' in message) { + // It's a notification + await this.#tryCallHandler(message.method, message.params); + } else if ('id' in message) { + // It's a response + this.#handleResponse(message as AnyResponse); + } + } + + async #tryCallHandler( + method: string, + params?: unknown, + ): Promise<Result<unknown>> { + try { + const result = await this.#handler(method, params); + return { result: result ?? null }; + } catch (error: unknown) { + if (error instanceof RequestError) { + return error.toResult(); + } + + if (error instanceof z.ZodError) { + return RequestError.invalidParams( + JSON.stringify(error.format(), undefined, 2), + ).toResult(); + } + + let details; + + if (error instanceof Error) { + details = error.message; + } else if ( + typeof error === 'object' && + error != null && + 'message' in error && + typeof error.message === 'string' + ) { + details = error.message; + } + + return RequestError.internalError(details).toResult(); + } + } + + #handleResponse(response: AnyResponse) { + const pendingResponse = this.#pendingResponses.get(response.id); + if (pendingResponse) { + if ('result' in response) { + pendingResponse.resolve(response.result); + } else if ('error' in response) { + pendingResponse.reject(response.error); + } + this.#pendingResponses.delete(response.id); + } + } + + async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> { + const id = this.#nextRequestId++; + const responsePromise = new Promise((resolve, reject) => { + this.#pendingResponses.set(id, { resolve, reject }); + }); + await this.#sendMessage({ jsonrpc: '2.0', id, method, params }); + return responsePromise as Promise<Resp>; + } + + async sendNotification<N>(method: string, params?: N): Promise<void> { + await this.#sendMessage({ jsonrpc: '2.0', method, params }); + } + + async #sendMessage(json: AnyMessage) { + const content = JSON.stringify(json) + '\n'; + this.#writeQueue = this.#writeQueue + .then(async () => { + const writer = this.#peerInput.getWriter(); + try { + await writer.write(this.#textEncoder.encode(content)); + } finally { + writer.releaseLock(); + } + }) + .catch((error) => { + // Continue processing writes on error + console.error('ACP write error:', error); + }); + return this.#writeQueue; + } +} + +export class RequestError extends Error { + data?: { details?: string }; + + constructor( + public code: number, + message: string, + details?: string, + ) { + super(message); + this.name = 'RequestError'; + if (details) { + this.data = { details }; + } + } + + static parseError(details?: string): RequestError { + return new RequestError(-32700, 'Parse error', details); + } + + static invalidRequest(details?: string): RequestError { + return new RequestError(-32600, 'Invalid request', details); + } + + static methodNotFound(details?: string): RequestError { + return new RequestError(-32601, 'Method not found', details); + } + + static invalidParams(details?: string): RequestError { + return new RequestError(-32602, 'Invalid params', details); + } + + static internalError(details?: string): RequestError { + return new RequestError(-32603, 'Internal error', details); + } + + static authRequired(details?: string): RequestError { + return new RequestError(-32000, 'Authentication required', details); + } + + toResult<T>(): Result<T> { + return { + error: { + code: this.code, + message: this.message, + data: this.data, + }, + }; + } +} + +export interface Client { + requestPermission( + params: schema.RequestPermissionRequest, + ): Promise<schema.RequestPermissionResponse>; + sessionUpdate(params: schema.SessionNotification): Promise<void>; + writeTextFile( + params: schema.WriteTextFileRequest, + ): Promise<schema.WriteTextFileResponse>; + readTextFile( + params: schema.ReadTextFileRequest, + ): Promise<schema.ReadTextFileResponse>; +} + +export interface Agent { + initialize( + params: schema.InitializeRequest, + ): Promise<schema.InitializeResponse>; + newSession( + params: schema.NewSessionRequest, + ): Promise<schema.NewSessionResponse>; + loadSession?( + params: schema.LoadSessionRequest, + ): Promise<schema.LoadSessionResponse>; + authenticate(params: schema.AuthenticateRequest): Promise<void>; + prompt(params: schema.PromptRequest): Promise<schema.PromptResponse>; + cancel(params: schema.CancelNotification): Promise<void>; +} diff --git a/packages/cli/src/zed-integration/schema.ts b/packages/cli/src/zed-integration/schema.ts new file mode 100644 index 00000000..4c962131 --- /dev/null +++ b/packages/cli/src/zed-integration/schema.ts @@ -0,0 +1,457 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; + +export const AGENT_METHODS = { + authenticate: 'authenticate', + initialize: 'initialize', + session_cancel: 'session/cancel', + session_load: 'session/load', + session_new: 'session/new', + session_prompt: 'session/prompt', +}; + +export const CLIENT_METHODS = { + fs_read_text_file: 'fs/read_text_file', + fs_write_text_file: 'fs/write_text_file', + session_request_permission: 'session/request_permission', + session_update: 'session/update', +}; + +export const PROTOCOL_VERSION = 1; + +export type WriteTextFileRequest = z.infer<typeof writeTextFileRequestSchema>; + +export type ReadTextFileRequest = z.infer<typeof readTextFileRequestSchema>; + +export type PermissionOptionKind = z.infer<typeof permissionOptionKindSchema>; + +export type Role = z.infer<typeof roleSchema>; + +export type TextResourceContents = z.infer<typeof textResourceContentsSchema>; + +export type BlobResourceContents = z.infer<typeof blobResourceContentsSchema>; + +export type ToolKind = z.infer<typeof toolKindSchema>; + +export type ToolCallStatus = z.infer<typeof toolCallStatusSchema>; + +export type WriteTextFileResponse = z.infer<typeof writeTextFileResponseSchema>; + +export type ReadTextFileResponse = z.infer<typeof readTextFileResponseSchema>; + +export type RequestPermissionOutcome = z.infer< + typeof requestPermissionOutcomeSchema +>; + +export type CancelNotification = z.infer<typeof cancelNotificationSchema>; + +export type AuthenticateRequest = z.infer<typeof authenticateRequestSchema>; + +export type AuthenticateResponse = z.infer<typeof authenticateResponseSchema>; + +export type NewSessionResponse = z.infer<typeof newSessionResponseSchema>; + +export type LoadSessionResponse = z.infer<typeof loadSessionResponseSchema>; + +export type StopReason = z.infer<typeof stopReasonSchema>; + +export type PromptResponse = z.infer<typeof promptResponseSchema>; + +export type ToolCallLocation = z.infer<typeof toolCallLocationSchema>; + +export type PlanEntry = z.infer<typeof planEntrySchema>; + +export type PermissionOption = z.infer<typeof permissionOptionSchema>; + +export type Annotations = z.infer<typeof annotationsSchema>; + +export type RequestPermissionResponse = z.infer< + typeof requestPermissionResponseSchema +>; + +export type FileSystemCapability = z.infer<typeof fileSystemCapabilitySchema>; + +export type EnvVariable = z.infer<typeof envVariableSchema>; + +export type McpServer = z.infer<typeof mcpServerSchema>; + +export type AgentCapabilities = z.infer<typeof agentCapabilitiesSchema>; + +export type AuthMethod = z.infer<typeof authMethodSchema>; + +export type ClientResponse = z.infer<typeof clientResponseSchema>; + +export type ClientNotification = z.infer<typeof clientNotificationSchema>; + +export type EmbeddedResourceResource = z.infer< + typeof embeddedResourceResourceSchema +>; + +export type NewSessionRequest = z.infer<typeof newSessionRequestSchema>; + +export type LoadSessionRequest = z.infer<typeof loadSessionRequestSchema>; + +export type InitializeResponse = z.infer<typeof initializeResponseSchema>; + +export type ContentBlock = z.infer<typeof contentBlockSchema>; + +export type ToolCallContent = z.infer<typeof toolCallContentSchema>; + +export type ToolCall = z.infer<typeof toolCallSchema>; + +export type ClientCapabilities = z.infer<typeof clientCapabilitiesSchema>; + +export type PromptRequest = z.infer<typeof promptRequestSchema>; + +export type SessionUpdate = z.infer<typeof sessionUpdateSchema>; + +export type AgentResponse = z.infer<typeof agentResponseSchema>; + +export type RequestPermissionRequest = z.infer< + typeof requestPermissionRequestSchema +>; + +export type InitializeRequest = z.infer<typeof initializeRequestSchema>; + +export type SessionNotification = z.infer<typeof sessionNotificationSchema>; + +export type ClientRequest = z.infer<typeof clientRequestSchema>; + +export type AgentRequest = z.infer<typeof agentRequestSchema>; + +export type AgentNotification = z.infer<typeof agentNotificationSchema>; + +export const writeTextFileRequestSchema = z.object({ + content: z.string(), + path: z.string(), + sessionId: z.string(), +}); + +export const readTextFileRequestSchema = z.object({ + limit: z.number().optional().nullable(), + line: z.number().optional().nullable(), + path: z.string(), + sessionId: z.string(), +}); + +export const permissionOptionKindSchema = z.union([ + z.literal('allow_once'), + z.literal('allow_always'), + z.literal('reject_once'), + z.literal('reject_always'), +]); + +export const roleSchema = z.union([z.literal('assistant'), z.literal('user')]); + +export const textResourceContentsSchema = z.object({ + mimeType: z.string().optional().nullable(), + text: z.string(), + uri: z.string(), +}); + +export const blobResourceContentsSchema = z.object({ + blob: z.string(), + mimeType: z.string().optional().nullable(), + uri: z.string(), +}); + +export const toolKindSchema = z.union([ + z.literal('read'), + z.literal('edit'), + z.literal('delete'), + z.literal('move'), + z.literal('search'), + z.literal('execute'), + z.literal('think'), + z.literal('fetch'), + z.literal('other'), +]); + +export const toolCallStatusSchema = z.union([ + z.literal('pending'), + z.literal('in_progress'), + z.literal('completed'), + z.literal('failed'), +]); + +export const writeTextFileResponseSchema = z.null(); + +export const readTextFileResponseSchema = z.object({ + content: z.string(), +}); + +export const requestPermissionOutcomeSchema = z.union([ + z.object({ + outcome: z.literal('cancelled'), + }), + z.object({ + optionId: z.string(), + outcome: z.literal('selected'), + }), +]); + +export const cancelNotificationSchema = z.object({ + sessionId: z.string(), +}); + +export const authenticateRequestSchema = z.object({ + methodId: z.string(), +}); + +export const authenticateResponseSchema = z.null(); + +export const newSessionResponseSchema = z.object({ + sessionId: z.string(), +}); + +export const loadSessionResponseSchema = z.null(); + +export const stopReasonSchema = z.union([ + z.literal('end_turn'), + z.literal('max_tokens'), + z.literal('refusal'), + z.literal('cancelled'), +]); + +export const promptResponseSchema = z.object({ + stopReason: stopReasonSchema, +}); + +export const toolCallLocationSchema = z.object({ + line: z.number().optional().nullable(), + path: z.string(), +}); + +export const planEntrySchema = z.object({ + content: z.string(), + priority: z.union([z.literal('high'), z.literal('medium'), z.literal('low')]), + status: z.union([ + z.literal('pending'), + z.literal('in_progress'), + z.literal('completed'), + ]), +}); + +export const permissionOptionSchema = z.object({ + kind: permissionOptionKindSchema, + name: z.string(), + optionId: z.string(), +}); + +export const annotationsSchema = z.object({ + audience: z.array(roleSchema).optional().nullable(), + lastModified: z.string().optional().nullable(), + priority: z.number().optional().nullable(), +}); + +export const requestPermissionResponseSchema = z.object({ + outcome: requestPermissionOutcomeSchema, +}); + +export const fileSystemCapabilitySchema = z.object({ + readTextFile: z.boolean(), + writeTextFile: z.boolean(), +}); + +export const envVariableSchema = z.object({ + name: z.string(), + value: z.string(), +}); + +export const mcpServerSchema = z.object({ + args: z.array(z.string()), + command: z.string(), + env: z.array(envVariableSchema), + name: z.string(), +}); + +export const agentCapabilitiesSchema = z.object({ + loadSession: z.boolean(), +}); + +export const authMethodSchema = z.object({ + description: z.string().nullable(), + id: z.string(), + name: z.string(), +}); + +export const clientResponseSchema = z.union([ + writeTextFileResponseSchema, + readTextFileResponseSchema, + requestPermissionResponseSchema, +]); + +export const clientNotificationSchema = cancelNotificationSchema; + +export const embeddedResourceResourceSchema = z.union([ + textResourceContentsSchema, + blobResourceContentsSchema, +]); + +export const newSessionRequestSchema = z.object({ + cwd: z.string(), + mcpServers: z.array(mcpServerSchema), +}); + +export const loadSessionRequestSchema = z.object({ + cwd: z.string(), + mcpServers: z.array(mcpServerSchema), + sessionId: z.string(), +}); + +export const initializeResponseSchema = z.object({ + agentCapabilities: agentCapabilitiesSchema, + authMethods: z.array(authMethodSchema), + protocolVersion: z.number(), +}); + +export const contentBlockSchema = z.union([ + z.object({ + annotations: annotationsSchema.optional().nullable(), + text: z.string(), + type: z.literal('text'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + data: z.string(), + mimeType: z.string(), + type: z.literal('image'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + data: z.string(), + mimeType: z.string(), + type: z.literal('audio'), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + description: z.string().optional().nullable(), + mimeType: z.string().optional().nullable(), + name: z.string(), + size: z.number().optional().nullable(), + title: z.string().optional().nullable(), + type: z.literal('resource_link'), + uri: z.string(), + }), + z.object({ + annotations: annotationsSchema.optional().nullable(), + resource: embeddedResourceResourceSchema, + type: z.literal('resource'), + }), +]); + +export const toolCallContentSchema = z.union([ + z.object({ + content: contentBlockSchema, + type: z.literal('content'), + }), + z.object({ + newText: z.string(), + oldText: z.string().nullable(), + path: z.string(), + type: z.literal('diff'), + }), +]); + +export const toolCallSchema = z.object({ + content: z.array(toolCallContentSchema).optional(), + kind: toolKindSchema, + locations: z.array(toolCallLocationSchema).optional(), + rawInput: z.unknown().optional(), + status: toolCallStatusSchema, + title: z.string(), + toolCallId: z.string(), +}); + +export const clientCapabilitiesSchema = z.object({ + fs: fileSystemCapabilitySchema, +}); + +export const promptRequestSchema = z.object({ + prompt: z.array(contentBlockSchema), + sessionId: z.string(), +}); + +export const sessionUpdateSchema = z.union([ + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('user_message_chunk'), + }), + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('agent_message_chunk'), + }), + z.object({ + content: contentBlockSchema, + sessionUpdate: z.literal('agent_thought_chunk'), + }), + z.object({ + content: z.array(toolCallContentSchema).optional(), + kind: toolKindSchema, + locations: z.array(toolCallLocationSchema).optional(), + rawInput: z.unknown().optional(), + sessionUpdate: z.literal('tool_call'), + status: toolCallStatusSchema, + title: z.string(), + toolCallId: z.string(), + }), + z.object({ + content: z.array(toolCallContentSchema).optional().nullable(), + kind: toolKindSchema.optional().nullable(), + locations: z.array(toolCallLocationSchema).optional().nullable(), + rawInput: z.unknown().optional(), + sessionUpdate: z.literal('tool_call_update'), + status: toolCallStatusSchema.optional().nullable(), + title: z.string().optional().nullable(), + toolCallId: z.string(), + }), + z.object({ + entries: z.array(planEntrySchema), + sessionUpdate: z.literal('plan'), + }), +]); + +export const agentResponseSchema = z.union([ + initializeResponseSchema, + authenticateResponseSchema, + newSessionResponseSchema, + loadSessionResponseSchema, + promptResponseSchema, +]); + +export const requestPermissionRequestSchema = z.object({ + options: z.array(permissionOptionSchema), + sessionId: z.string(), + toolCall: toolCallSchema, +}); + +export const initializeRequestSchema = z.object({ + clientCapabilities: clientCapabilitiesSchema, + protocolVersion: z.number(), +}); + +export const sessionNotificationSchema = z.object({ + sessionId: z.string(), + update: sessionUpdateSchema, +}); + +export const clientRequestSchema = z.union([ + writeTextFileRequestSchema, + readTextFileRequestSchema, + requestPermissionRequestSchema, +]); + +export const agentRequestSchema = z.union([ + initializeRequestSchema, + authenticateRequestSchema, + newSessionRequestSchema, + loadSessionRequestSchema, + promptRequestSchema, +]); + +export const agentNotificationSchema = sessionNotificationSchema; diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts new file mode 100644 index 00000000..1b5baa8a --- /dev/null +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -0,0 +1,849 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { WritableStream, ReadableStream } from 'node:stream/web'; + +import { + AuthType, + Config, + GeminiChat, + ToolRegistry, + logToolCall, + ToolResult, + convertToFunctionResponse, + ToolCallConfirmationDetails, + ToolConfirmationOutcome, + clearCachedCredentialFile, + isNodeError, + getErrorMessage, + isWithinRoot, + getErrorStatus, + MCPServerConfig, +} from '@google/gemini-cli-core'; +import * as acp from './acp.js'; +import { Readable, Writable } from 'node:stream'; +import { Content, Part, FunctionCall, PartListUnion } from '@google/genai'; +import { LoadedSettings, SettingScope } from '../config/settings.js'; +import * as fs from 'fs/promises'; +import * as path from 'path'; +import { z } from 'zod'; + +import { randomUUID } from 'crypto'; +import { Extension } from '../config/extension.js'; +import { CliArgs, loadCliConfig } from '../config/config.js'; + +export async function runZedIntegration( + config: Config, + settings: LoadedSettings, + extensions: Extension[], + argv: CliArgs, +) { + const stdout = Writable.toWeb(process.stdout) as WritableStream; + const stdin = Readable.toWeb(process.stdin) as ReadableStream<Uint8Array>; + + // Stdout is used to send messages to the client, so console.log/console.info + // messages to stderr so that they don't interfere with ACP. + console.log = console.error; + console.info = console.error; + console.debug = console.error; + + new acp.AgentSideConnection( + (client: acp.Client) => + new GeminiAgent(config, settings, extensions, argv, client), + stdout, + stdin, + ); +} + +class GeminiAgent { + private sessions: Map<string, Session> = new Map(); + + constructor( + private config: Config, + private settings: LoadedSettings, + private extensions: Extension[], + private argv: CliArgs, + private client: acp.Client, + ) {} + + async initialize( + _args: acp.InitializeRequest, + ): Promise<acp.InitializeResponse> { + const authMethods = [ + { + id: AuthType.LOGIN_WITH_GOOGLE, + name: 'Log in with Google', + description: null, + }, + { + id: AuthType.USE_GEMINI, + name: 'Use Gemini API key', + description: + 'Requires setting the `GEMINI_API_KEY` environment variable', + }, + { + id: AuthType.USE_VERTEX_AI, + name: 'Vertex AI', + description: null, + }, + ]; + + return { + protocolVersion: acp.PROTOCOL_VERSION, + authMethods, + agentCapabilities: { + loadSession: false, + }, + }; + } + + async authenticate({ methodId }: acp.AuthenticateRequest): Promise<void> { + const method = z.nativeEnum(AuthType).parse(methodId); + + await clearCachedCredentialFile(); + await this.config.refreshAuth(method); + this.settings.setValue(SettingScope.User, 'selectedAuthType', method); + } + + async newSession({ + cwd, + mcpServers, + }: acp.NewSessionRequest): Promise<acp.NewSessionResponse> { + const sessionId = randomUUID(); + const config = await this.newSessionConfig(sessionId, cwd, mcpServers); + + let isAuthenticated = false; + if (this.settings.merged.selectedAuthType) { + try { + await config.refreshAuth(this.settings.merged.selectedAuthType); + isAuthenticated = true; + } catch (e) { + console.error(`Authentication failed: ${e}`); + } + } + + if (!isAuthenticated) { + throw acp.RequestError.authRequired(); + } + + const geminiClient = config.getGeminiClient(); + const chat = await geminiClient.startChat(); + const session = new Session(sessionId, chat, config, this.client); + this.sessions.set(sessionId, session); + + return { + sessionId, + }; + } + + async newSessionConfig( + sessionId: string, + cwd: string, + mcpServers: acp.McpServer[], + ): Promise<Config> { + const mergedMcpServers = { ...this.settings.merged.mcpServers }; + + for (const { command, args, env: rawEnv, name } of mcpServers) { + const env: Record<string, string> = {}; + for (const { name: envName, value } of rawEnv) { + env[envName] = value; + } + mergedMcpServers[name] = new MCPServerConfig(command, args, env, cwd); + } + + const settings = { ...this.settings.merged, mcpServers: mergedMcpServers }; + + const config = await loadCliConfig( + settings, + this.extensions, + sessionId, + this.argv, + cwd, + ); + + await config.initialize(); + return config; + } + + async cancel(params: acp.CancelNotification): Promise<void> { + const session = this.sessions.get(params.sessionId); + if (!session) { + throw new Error(`Session not found: ${params.sessionId}`); + } + await session.cancelPendingPrompt(); + } + + async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> { + const session = this.sessions.get(params.sessionId); + if (!session) { + throw new Error(`Session not found: ${params.sessionId}`); + } + return session.prompt(params); + } +} + +class Session { + private pendingPrompt: AbortController | null = null; + + constructor( + private readonly id: string, + private readonly chat: GeminiChat, + private readonly config: Config, + private readonly client: acp.Client, + ) {} + + async cancelPendingPrompt(): Promise<void> { + if (!this.pendingPrompt) { + throw new Error('Not currently generating'); + } + + this.pendingPrompt.abort(); + this.pendingPrompt = null; + } + + async prompt(params: acp.PromptRequest): Promise<acp.PromptResponse> { + this.pendingPrompt?.abort(); + const pendingSend = new AbortController(); + this.pendingPrompt = pendingSend; + + const promptId = Math.random().toString(16).slice(2); + const chat = this.chat; + + const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal); + + let nextMessage: Content | null = { role: 'user', parts }; + + while (nextMessage !== null) { + if (pendingSend.signal.aborted) { + chat.addHistory(nextMessage); + return { stopReason: 'cancelled' }; + } + + const functionCalls: FunctionCall[] = []; + + try { + const responseStream = await chat.sendMessageStream( + { + message: nextMessage?.parts ?? [], + config: { + abortSignal: pendingSend.signal, + }, + }, + promptId, + ); + nextMessage = null; + + for await (const resp of responseStream) { + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } + + if (resp.candidates && resp.candidates.length > 0) { + const candidate = resp.candidates[0]; + for (const part of candidate.content?.parts ?? []) { + if (!part.text) { + continue; + } + + const content: acp.ContentBlock = { + type: 'text', + text: part.text, + }; + + this.sendUpdate({ + sessionUpdate: part.thought + ? 'agent_thought_chunk' + : 'agent_message_chunk', + content, + }); + } + } + + if (resp.functionCalls) { + functionCalls.push(...resp.functionCalls); + } + } + } catch (error) { + if (getErrorStatus(error) === 429) { + throw new acp.RequestError( + 429, + 'Rate limit exceeded. Try again later.', + ); + } + + throw error; + } + + if (functionCalls.length > 0) { + const toolResponseParts: Part[] = []; + + for (const fc of functionCalls) { + const response = await this.runTool(pendingSend.signal, promptId, fc); + + const parts = Array.isArray(response) ? response : [response]; + + for (const part of parts) { + if (typeof part === 'string') { + toolResponseParts.push({ text: part }); + } else if (part) { + toolResponseParts.push(part); + } + } + } + + nextMessage = { role: 'user', parts: toolResponseParts }; + } + } + + return { stopReason: 'end_turn' }; + } + + private async sendUpdate(update: acp.SessionUpdate): Promise<void> { + const params: acp.SessionNotification = { + sessionId: this.id, + update, + }; + + await this.client.sessionUpdate(params); + } + + private async runTool( + abortSignal: AbortSignal, + promptId: string, + fc: FunctionCall, + ): Promise<PartListUnion> { + const callId = fc.id ?? `${fc.name}-${Date.now()}`; + const args = (fc.args ?? {}) as Record<string, unknown>; + + const startTime = Date.now(); + + const errorResponse = (error: Error) => { + const durationMs = Date.now() - startTime; + logToolCall(this.config, { + 'event.name': 'tool_call', + 'event.timestamp': new Date().toISOString(), + prompt_id: promptId, + function_name: fc.name ?? '', + function_args: args, + duration_ms: durationMs, + success: false, + error: error.message, + }); + + return [ + { + functionResponse: { + id: callId, + name: fc.name ?? '', + response: { error: error.message }, + }, + }, + ]; + }; + + if (!fc.name) { + return errorResponse(new Error('Missing function name')); + } + + const toolRegistry: ToolRegistry = await this.config.getToolRegistry(); + const tool = toolRegistry.getTool(fc.name as string); + + if (!tool) { + return errorResponse( + new Error(`Tool "${fc.name}" not found in registry.`), + ); + } + + const invocation = tool.build(args); + const confirmationDetails = + await invocation.shouldConfirmExecute(abortSignal); + + if (confirmationDetails) { + const content: acp.ToolCallContent[] = []; + + if (confirmationDetails.type === 'edit') { + content.push({ + type: 'diff', + path: confirmationDetails.fileName, + oldText: confirmationDetails.originalContent, + newText: confirmationDetails.newContent, + }); + } + + const params: acp.RequestPermissionRequest = { + sessionId: this.id, + options: toPermissionOptions(confirmationDetails), + toolCall: { + toolCallId: callId, + status: 'pending', + title: invocation.getDescription(), + content, + locations: invocation.toolLocations(), + kind: tool.kind, + }, + }; + + const output = await this.client.requestPermission(params); + const outcome = + output.outcome.outcome === 'cancelled' + ? ToolConfirmationOutcome.Cancel + : z + .nativeEnum(ToolConfirmationOutcome) + .parse(output.outcome.optionId); + + await confirmationDetails.onConfirm(outcome); + + switch (outcome) { + case ToolConfirmationOutcome.Cancel: + return errorResponse( + new Error(`Tool "${fc.name}" was canceled by the user.`), + ); + case ToolConfirmationOutcome.ProceedOnce: + case ToolConfirmationOutcome.ProceedAlways: + case ToolConfirmationOutcome.ProceedAlwaysServer: + case ToolConfirmationOutcome.ProceedAlwaysTool: + case ToolConfirmationOutcome.ModifyWithEditor: + break; + default: { + const resultOutcome: never = outcome; + throw new Error(`Unexpected: ${resultOutcome}`); + } + } + } else { + await this.sendUpdate({ + sessionUpdate: 'tool_call', + toolCallId: callId, + status: 'in_progress', + title: invocation.getDescription(), + content: [], + locations: invocation.toolLocations(), + kind: tool.kind, + }); + } + + try { + const toolResult: ToolResult = await invocation.execute(abortSignal); + const content = toToolCallContent(toolResult); + + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'completed', + content: content ? [content] : [], + }); + + const durationMs = Date.now() - startTime; + logToolCall(this.config, { + 'event.name': 'tool_call', + 'event.timestamp': new Date().toISOString(), + function_name: fc.name, + function_args: args, + duration_ms: durationMs, + success: true, + prompt_id: promptId, + }); + + return convertToFunctionResponse(fc.name, callId, toolResult.llmContent); + } catch (e) { + const error = e instanceof Error ? e : new Error(String(e)); + + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'failed', + content: [ + { type: 'content', content: { type: 'text', text: error.message } }, + ], + }); + + return errorResponse(error); + } + } + + async #resolvePrompt( + message: acp.ContentBlock[], + abortSignal: AbortSignal, + ): Promise<Part[]> { + const parts = message.map((part) => { + switch (part.type) { + case 'text': + return { text: part.text }; + case 'resource_link': + return { + fileData: { + mimeData: part.mimeType, + name: part.name, + fileUri: part.uri, + }, + }; + case 'resource': { + return { + fileData: { + mimeData: part.resource.mimeType, + name: part.resource.uri, + fileUri: part.resource.uri, + }, + }; + } + default: { + throw new Error(`Unexpected chunk type: '${part.type}'`); + } + } + }); + + const atPathCommandParts = parts.filter((part) => 'fileData' in part); + + if (atPathCommandParts.length === 0) { + return parts; + } + + // Get centralized file discovery service + const fileDiscovery = this.config.getFileService(); + const respectGitIgnore = this.config.getFileFilteringRespectGitIgnore(); + + const pathSpecsToRead: string[] = []; + const atPathToResolvedSpecMap = new Map<string, string>(); + const contentLabelsForDisplay: string[] = []; + const ignoredPaths: string[] = []; + + const toolRegistry = await this.config.getToolRegistry(); + const readManyFilesTool = toolRegistry.getTool('read_many_files'); + const globTool = toolRegistry.getTool('glob'); + + if (!readManyFilesTool) { + throw new Error('Error: read_many_files tool not found.'); + } + + for (const atPathPart of atPathCommandParts) { + const pathName = atPathPart.fileData!.fileUri; + // Check if path should be ignored by git + if (fileDiscovery.shouldGitIgnoreFile(pathName)) { + ignoredPaths.push(pathName); + const reason = respectGitIgnore + ? 'git-ignored and will be skipped' + : 'ignored by custom patterns'; + console.warn(`Path ${pathName} is ${reason}.`); + continue; + } + let currentPathSpec = pathName; + let resolvedSuccessfully = false; + try { + const absolutePath = path.resolve(this.config.getTargetDir(), pathName); + if (isWithinRoot(absolutePath, this.config.getTargetDir())) { + const stats = await fs.stat(absolutePath); + if (stats.isDirectory()) { + currentPathSpec = pathName.endsWith('/') + ? `${pathName}**` + : `${pathName}/**`; + this.debug( + `Path ${pathName} resolved to directory, using glob: ${currentPathSpec}`, + ); + } else { + this.debug(`Path ${pathName} resolved to file: ${currentPathSpec}`); + } + resolvedSuccessfully = true; + } else { + this.debug( + `Path ${pathName} is outside the project directory. Skipping.`, + ); + } + } catch (error) { + if (isNodeError(error) && error.code === 'ENOENT') { + if (this.config.getEnableRecursiveFileSearch() && globTool) { + this.debug( + `Path ${pathName} not found directly, attempting glob search.`, + ); + try { + const globResult = await globTool.buildAndExecute( + { + pattern: `**/*${pathName}*`, + path: this.config.getTargetDir(), + }, + abortSignal, + ); + if ( + globResult.llmContent && + typeof globResult.llmContent === 'string' && + !globResult.llmContent.startsWith('No files found') && + !globResult.llmContent.startsWith('Error:') + ) { + const lines = globResult.llmContent.split('\n'); + if (lines.length > 1 && lines[1]) { + const firstMatchAbsolute = lines[1].trim(); + currentPathSpec = path.relative( + this.config.getTargetDir(), + firstMatchAbsolute, + ); + this.debug( + `Glob search for ${pathName} found ${firstMatchAbsolute}, using relative path: ${currentPathSpec}`, + ); + resolvedSuccessfully = true; + } else { + this.debug( + `Glob search for '**/*${pathName}*' did not return a usable path. Path ${pathName} will be skipped.`, + ); + } + } else { + this.debug( + `Glob search for '**/*${pathName}*' found no files or an error. Path ${pathName} will be skipped.`, + ); + } + } catch (globError) { + console.error( + `Error during glob search for ${pathName}: ${getErrorMessage(globError)}`, + ); + } + } else { + this.debug( + `Glob tool not found. Path ${pathName} will be skipped.`, + ); + } + } else { + console.error( + `Error stating path ${pathName}. Path ${pathName} will be skipped.`, + ); + } + } + if (resolvedSuccessfully) { + pathSpecsToRead.push(currentPathSpec); + atPathToResolvedSpecMap.set(pathName, currentPathSpec); + contentLabelsForDisplay.push(pathName); + } + } + // Construct the initial part of the query for the LLM + let initialQueryText = ''; + for (let i = 0; i < parts.length; i++) { + const chunk = parts[i]; + if ('text' in chunk) { + initialQueryText += chunk.text; + } else { + // type === 'atPath' + const resolvedSpec = + chunk.fileData && atPathToResolvedSpecMap.get(chunk.fileData.fileUri); + if ( + i > 0 && + initialQueryText.length > 0 && + !initialQueryText.endsWith(' ') && + resolvedSpec + ) { + // Add space if previous part was text and didn't end with space, or if previous was @path + const prevPart = parts[i - 1]; + if ( + 'text' in prevPart || + ('fileData' in prevPart && + atPathToResolvedSpecMap.has(prevPart.fileData!.fileUri)) + ) { + initialQueryText += ' '; + } + } + if (resolvedSpec) { + initialQueryText += `@${resolvedSpec}`; + } else { + // If not resolved for reading (e.g. lone @ or invalid path that was skipped), + // add the original @-string back, ensuring spacing if it's not the first element. + if ( + i > 0 && + initialQueryText.length > 0 && + !initialQueryText.endsWith(' ') && + !chunk.fileData?.fileUri.startsWith(' ') + ) { + initialQueryText += ' '; + } + if (chunk.fileData?.fileUri) { + initialQueryText += `@${chunk.fileData.fileUri}`; + } + } + } + } + initialQueryText = initialQueryText.trim(); + // Inform user about ignored paths + if (ignoredPaths.length > 0) { + const ignoreType = respectGitIgnore ? 'git-ignored' : 'custom-ignored'; + this.debug( + `Ignored ${ignoredPaths.length} ${ignoreType} files: ${ignoredPaths.join(', ')}`, + ); + } + // Fallback for lone "@" or completely invalid @-commands resulting in empty initialQueryText + if (pathSpecsToRead.length === 0) { + console.warn('No valid file paths found in @ commands to read.'); + return [{ text: initialQueryText }]; + } + const processedQueryParts: Part[] = [{ text: initialQueryText }]; + const toolArgs = { + paths: pathSpecsToRead, + respectGitIgnore, // Use configuration setting + }; + + const callId = `${readManyFilesTool.name}-${Date.now()}`; + + try { + const invocation = readManyFilesTool.build(toolArgs); + + await this.sendUpdate({ + sessionUpdate: 'tool_call', + toolCallId: callId, + status: 'in_progress', + title: invocation.getDescription(), + content: [], + locations: invocation.toolLocations(), + kind: readManyFilesTool.kind, + }); + + const result = await invocation.execute(abortSignal); + const content = toToolCallContent(result) || { + type: 'content', + content: { + type: 'text', + text: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, + }, + }; + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'completed', + content: content ? [content] : [], + }); + if (Array.isArray(result.llmContent)) { + const fileContentRegex = /^--- (.*?) ---\n\n([\s\S]*?)\n\n$/; + processedQueryParts.push({ + text: '\n--- Content from referenced files ---', + }); + for (const part of result.llmContent) { + if (typeof part === 'string') { + const match = fileContentRegex.exec(part); + if (match) { + const filePathSpecInContent = match[1]; // This is a resolved pathSpec + const fileActualContent = match[2].trim(); + processedQueryParts.push({ + text: `\nContent from @${filePathSpecInContent}:\n`, + }); + processedQueryParts.push({ text: fileActualContent }); + } else { + processedQueryParts.push({ text: part }); + } + } else { + // part is a Part object. + processedQueryParts.push(part); + } + } + processedQueryParts.push({ text: '\n--- End of content ---' }); + } else { + console.warn( + 'read_many_files tool returned no content or empty content.', + ); + } + return processedQueryParts; + } catch (error: unknown) { + await this.sendUpdate({ + sessionUpdate: 'tool_call_update', + toolCallId: callId, + status: 'failed', + content: [ + { + type: 'content', + content: { + type: 'text', + text: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, + }, + }, + ], + }); + + throw error; + } + } + + debug(msg: string) { + if (this.config.getDebugMode()) { + console.warn(msg); + } + } +} + +function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null { + if (toolResult.returnDisplay) { + if (typeof toolResult.returnDisplay === 'string') { + return { + type: 'content', + content: { type: 'text', text: toolResult.returnDisplay }, + }; + } else { + return { + type: 'diff', + path: toolResult.returnDisplay.fileName, + oldText: toolResult.returnDisplay.originalContent, + newText: toolResult.returnDisplay.newContent, + }; + } + } else { + return null; + } +} + +const basicPermissionOptions = [ + { + optionId: ToolConfirmationOutcome.ProceedOnce, + name: 'Allow', + kind: 'allow_once', + }, + { + optionId: ToolConfirmationOutcome.Cancel, + name: 'Reject', + kind: 'reject_once', + }, +] as const; + +function toPermissionOptions( + confirmation: ToolCallConfirmationDetails, +): acp.PermissionOption[] { + switch (confirmation.type) { + case 'edit': + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: 'Allow All Edits', + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; + case 'exec': + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: `Always Allow ${confirmation.rootCommand}`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; + case 'mcp': + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlwaysServer, + name: `Always Allow ${confirmation.serverName}`, + kind: 'allow_always', + }, + { + optionId: ToolConfirmationOutcome.ProceedAlwaysTool, + name: `Always Allow ${confirmation.toolName}`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; + case 'info': + return [ + { + optionId: ToolConfirmationOutcome.ProceedAlways, + name: `Always Allow`, + kind: 'allow_always', + }, + ...basicPermissionOptions, + ]; + default: { + const unreachable: never = confirmation; + throw new Error(`Unexpected: ${unreachable}`); + } + } +} |
