summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
-rw-r--r--packages/core/src/tools/mcp-client.ts130
1 files changed, 129 insertions, 1 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index e9001466..ede0d036 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -70,6 +70,134 @@ export enum MCPDiscoveryState {
}
/**
+ * A client for a single MCP server.
+ *
+ * This class is responsible for connecting to, discovering tools from, and
+ * managing the state of a single MCP server.
+ */
+export class McpClient {
+ private client: Client;
+ private transport: Transport | undefined;
+ private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
+ private isDisconnecting = false;
+
+ constructor(
+ private readonly serverName: string,
+ private readonly serverConfig: MCPServerConfig,
+ private readonly toolRegistry: ToolRegistry,
+ private readonly promptRegistry: PromptRegistry,
+ private readonly workspaceContext: WorkspaceContext,
+ private readonly debugMode: boolean,
+ ) {
+ this.client = new Client({
+ name: `gemini-cli-mcp-client-${this.serverName}`,
+ version: '0.0.1',
+ });
+ }
+
+ /**
+ * Connects to the MCP server.
+ */
+ async connect(): Promise<void> {
+ this.isDisconnecting = false;
+ this.updateStatus(MCPServerStatus.CONNECTING);
+ try {
+ this.transport = await this.createTransport();
+
+ this.client.onerror = (error) => {
+ if (this.isDisconnecting) {
+ return;
+ }
+ console.error(`MCP ERROR (${this.serverName}):`, error.toString());
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ };
+
+ this.client.registerCapabilities({
+ roots: {},
+ });
+
+ this.client.setRequestHandler(ListRootsRequestSchema, async () => {
+ const roots = [];
+ for (const dir of this.workspaceContext.getDirectories()) {
+ roots.push({
+ uri: pathToFileURL(dir).toString(),
+ name: basename(dir),
+ });
+ }
+ return {
+ roots,
+ };
+ });
+
+ await this.client.connect(this.transport, {
+ timeout: this.serverConfig.timeout,
+ });
+
+ this.updateStatus(MCPServerStatus.CONNECTED);
+ } catch (error) {
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ throw error;
+ }
+ }
+
+ /**
+ * Discovers tools and prompts from the MCP server.
+ */
+ async discover(): Promise<void> {
+ if (this.status !== MCPServerStatus.CONNECTED) {
+ throw new Error('Client is not connected.');
+ }
+
+ const prompts = await this.discoverPrompts();
+ const tools = await this.discoverTools();
+
+ if (prompts.length === 0 && tools.length === 0) {
+ throw new Error('No prompts or tools found on the server.');
+ }
+
+ for (const tool of tools) {
+ this.toolRegistry.registerTool(tool);
+ }
+ }
+
+ /**
+ * Disconnects from the MCP server.
+ */
+ async disconnect(): Promise<void> {
+ this.isDisconnecting = true;
+ if (this.transport) {
+ await this.transport.close();
+ }
+ this.client.close();
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ }
+
+ /**
+ * Returns the current status of the client.
+ */
+ getStatus(): MCPServerStatus {
+ return this.status;
+ }
+
+ private updateStatus(status: MCPServerStatus): void {
+ this.status = status;
+ updateMCPServerStatus(this.serverName, status);
+ }
+
+ private async createTransport(): Promise<Transport> {
+ return createTransport(this.serverName, this.serverConfig, this.debugMode);
+ }
+
+ private async discoverTools(): Promise<DiscoveredMCPTool[]> {
+ return discoverTools(this.serverName, this.serverConfig, this.client);
+ }
+
+ private async discoverPrompts(): Promise<Prompt[]> {
+ return discoverPrompts(this.serverName, this.client, this.promptRegistry);
+ }
+}
+
+/**
* Map to track the status of each MCP server within the core package
*/
const serverStatuses: Map<string, MCPServerStatus> = new Map();
@@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener(
/**
* Update the status of an MCP server
*/
-function updateMCPServerStatus(
+export function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {