diff options
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 130 |
1 files changed, 129 insertions, 1 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index e9001466..ede0d036 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -70,6 +70,134 @@ export enum MCPDiscoveryState { } /** + * A client for a single MCP server. + * + * This class is responsible for connecting to, discovering tools from, and + * managing the state of a single MCP server. + */ +export class McpClient { + private client: Client; + private transport: Transport | undefined; + private status: MCPServerStatus = MCPServerStatus.DISCONNECTED; + private isDisconnecting = false; + + constructor( + private readonly serverName: string, + private readonly serverConfig: MCPServerConfig, + private readonly toolRegistry: ToolRegistry, + private readonly promptRegistry: PromptRegistry, + private readonly workspaceContext: WorkspaceContext, + private readonly debugMode: boolean, + ) { + this.client = new Client({ + name: `gemini-cli-mcp-client-${this.serverName}`, + version: '0.0.1', + }); + } + + /** + * Connects to the MCP server. + */ + async connect(): Promise<void> { + this.isDisconnecting = false; + this.updateStatus(MCPServerStatus.CONNECTING); + try { + this.transport = await this.createTransport(); + + this.client.onerror = (error) => { + if (this.isDisconnecting) { + return; + } + console.error(`MCP ERROR (${this.serverName}):`, error.toString()); + this.updateStatus(MCPServerStatus.DISCONNECTED); + }; + + this.client.registerCapabilities({ + roots: {}, + }); + + this.client.setRequestHandler(ListRootsRequestSchema, async () => { + const roots = []; + for (const dir of this.workspaceContext.getDirectories()) { + roots.push({ + uri: pathToFileURL(dir).toString(), + name: basename(dir), + }); + } + return { + roots, + }; + }); + + await this.client.connect(this.transport, { + timeout: this.serverConfig.timeout, + }); + + this.updateStatus(MCPServerStatus.CONNECTED); + } catch (error) { + this.updateStatus(MCPServerStatus.DISCONNECTED); + throw error; + } + } + + /** + * Discovers tools and prompts from the MCP server. + */ + async discover(): Promise<void> { + if (this.status !== MCPServerStatus.CONNECTED) { + throw new Error('Client is not connected.'); + } + + const prompts = await this.discoverPrompts(); + const tools = await this.discoverTools(); + + if (prompts.length === 0 && tools.length === 0) { + throw new Error('No prompts or tools found on the server.'); + } + + for (const tool of tools) { + this.toolRegistry.registerTool(tool); + } + } + + /** + * Disconnects from the MCP server. + */ + async disconnect(): Promise<void> { + this.isDisconnecting = true; + if (this.transport) { + await this.transport.close(); + } + this.client.close(); + this.updateStatus(MCPServerStatus.DISCONNECTED); + } + + /** + * Returns the current status of the client. + */ + getStatus(): MCPServerStatus { + return this.status; + } + + private updateStatus(status: MCPServerStatus): void { + this.status = status; + updateMCPServerStatus(this.serverName, status); + } + + private async createTransport(): Promise<Transport> { + return createTransport(this.serverName, this.serverConfig, this.debugMode); + } + + private async discoverTools(): Promise<DiscoveredMCPTool[]> { + return discoverTools(this.serverName, this.serverConfig, this.client); + } + + private async discoverPrompts(): Promise<Prompt[]> { + return discoverPrompts(this.serverName, this.client, this.promptRegistry); + } +} + +/** * Map to track the status of each MCP server within the core package */ const serverStatuses: Map<string, MCPServerStatus> = new Map(); @@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener( /** * Update the status of an MCP server */ -function updateMCPServerStatus( +export function updateMCPServerStatus( serverName: string, status: MCPServerStatus, ): void { |
