summaryrefslogtreecommitdiff
path: root/packages/core/src/tools
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts70
-rw-r--r--packages/core/src/tools/mcp-client.ts20
-rw-r--r--packages/core/src/tools/tool-registry.test.ts20
-rw-r--r--packages/core/src/tools/tool-registry.ts6
4 files changed, 90 insertions, 26 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 121cd1d8..abd9c58f 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -135,7 +135,11 @@ describe('discoverMcpTools', () => {
});
it('should do nothing if no MCP servers or command are configured', async () => {
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
expect(Client).not.toHaveBeenCalled();
@@ -161,7 +165,11 @@ describe('discoverMcpTools', () => {
// In this case, listTools fails, so no tools are registered.
// The default mock `mockReturnValue([])` from beforeEach should apply.
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(parse).toHaveBeenCalledWith(commandString, process.env);
expect(StdioClientTransport).toHaveBeenCalledWith({
@@ -204,7 +212,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(StdioClientTransport).toHaveBeenCalledWith({
command: serverConfig.command,
@@ -239,7 +251,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
@@ -317,7 +333,11 @@ describe('discoverMcpTools', () => {
},
);
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
@@ -381,7 +401,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
const registeredTool = mockToolRegistry.registerTool.mock
@@ -410,9 +434,13 @@ describe('discoverMcpTools', () => {
});
vi.spyOn(console, 'error').mockImplementation(() => {});
- await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
- 'Parsing failed',
- );
+ await expect(
+ discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ ),
+ ).rejects.toThrow('Parsing failed');
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
expect(console.error).not.toHaveBeenCalled();
});
@@ -421,7 +449,11 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -442,7 +474,11 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -463,7 +499,11 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -483,7 +523,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
- await discoverMcpTools(mockConfig);
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
const clientInstances = vi.mocked(Client).mock.results;
expect(clientInstances.length).toBeGreaterThan(0);
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index 87835219..1b7823c7 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -8,15 +8,18 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { parse } from 'shell-quote';
-import { Config, MCPServerConfig } from '../config/config.js';
+import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
+import { ToolRegistry } from './tool-registry.js';
-export async function discoverMcpTools(config: Config): Promise<void> {
- const mcpServers = config.getMcpServers() || {};
-
- if (config.getMcpServerCommand()) {
- const cmd = config.getMcpServerCommand()!;
+export async function discoverMcpTools(
+ mcpServers: Record<string, MCPServerConfig>,
+ mcpServerCommand: string | undefined,
+ toolRegistry: ToolRegistry,
+): Promise<void> {
+ if (mcpServerCommand) {
+ const cmd = mcpServerCommand;
const args = parse(cmd, process.env) as string[];
if (args.some((arg) => typeof arg !== 'string')) {
throw new Error('failed to parse mcpServerCommand: ' + cmd);
@@ -30,7 +33,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
- connectAndDiscover(mcpServerName, mcpServerConfig, config),
+ connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry),
);
await Promise.all(discoveryPromises);
}
@@ -38,7 +41,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
- config: Config,
+ toolRegistry: ToolRegistry,
): Promise<void> {
let transport;
if (mcpServerConfig.url) {
@@ -90,7 +93,6 @@ async function connectAndDiscover(
});
}
- const toolRegistry = await config.getToolRegistry();
try {
const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
const discoveredToolFunctions = await mcpCallableTool.tool();
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index 1fb2df4e..f57f5bce 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -277,7 +277,11 @@ describe('ToolRegistry', () => {
await toolRegistry.discoverTools();
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
+ expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
+ mcpServerConfigVal,
+ undefined,
+ toolRegistry,
+ );
// We no longer check these as discoverMcpTools is mocked
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
// expect(Client).toHaveBeenCalledTimes(1);
@@ -302,7 +306,11 @@ describe('ToolRegistry', () => {
);
await toolRegistry.discoverTools();
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
+ expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
+ {},
+ 'mcp-server-start-command --param',
+ toolRegistry,
+ );
});
it('should handle errors during MCP client connection gracefully and close transport', async () => {
@@ -314,7 +322,13 @@ describe('ToolRegistry', () => {
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
await toolRegistry.discoverTools();
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
+ expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
+ {
+ 'failing-mcp': { command: 'fail-cmd' },
+ },
+ undefined,
+ toolRegistry,
+ );
expect(toolRegistry.getAllTools()).toHaveLength(0);
});
});
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index 12aa1a83..2b27a703 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -161,7 +161,11 @@ export class ToolRegistry {
}
}
// discover tools using MCP servers, if configured
- await discoverMcpTools(this.config);
+ await discoverMcpTools(
+ this.config.getMcpServers() ?? {},
+ this.config.getMcpServerCommand(),
+ this,
+ );
}
/**