summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/tool-registry.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/tool-registry.test.ts')
-rw-r--r--packages/core/src/tools/tool-registry.test.ts328
1 files changed, 227 insertions, 101 deletions
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index b39ec7b9..4d586d62 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -14,22 +14,22 @@ import {
afterEach,
Mocked,
} from 'vitest';
-import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
-import { DiscoveredMCPTool } from './mcp-tool.js';
import {
- Config,
- ConfigParameters,
- MCPServerConfig,
- ApprovalMode,
-} from '../config/config.js';
+ ToolRegistry,
+ DiscoveredTool,
+ sanitizeParameters,
+} from './tool-registry.js';
+import { DiscoveredMCPTool } from './mcp-tool.js';
+import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
import { BaseTool, ToolResult } from './tools.js';
import {
FunctionDeclaration,
CallableTool,
mcpToTool,
Type,
+ Schema,
} from '@google/genai';
-import { execSync } from 'node:child_process';
+import { spawn } from 'node:child_process';
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
@@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
set onerror(handler: any) {
mockMcpClientOnError(handler);
},
- // listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools
}));
return { Client: MockClient };
});
@@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => {
return {
...actualGenai,
mcpToTool: vi.fn().mockImplementation(() => ({
- // Default mock implementation
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
callTool: vi.fn(),
})),
@@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = {
describe('ToolRegistry', () => {
let config: Config;
let toolRegistry: ToolRegistry;
+ let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
config = new Config(baseConfigParams);
@@ -148,13 +147,19 @@ describe('ToolRegistry', () => {
vi.spyOn(console, 'debug').mockImplementation(() => {});
vi.spyOn(console, 'log').mockImplementation(() => {});
- // Reset mocks for MCP parts
- mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success
+ mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
mockStdioTransportClose.mockReset();
mockSseTransportClose.mockReset();
vi.mocked(mcpToTool).mockClear();
- // Default mcpToTool to return a callable tool that returns no functions
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
+
+ mockConfigGetToolDiscoveryCommand = vi.spyOn(
+ config,
+ 'getToolDiscoveryCommand',
+ );
+ vi.spyOn(config, 'getMcpServers');
+ vi.spyOn(config, 'getMcpServerCommand');
+ mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
afterEach(() => {
@@ -167,21 +172,18 @@ describe('ToolRegistry', () => {
toolRegistry.registerTool(tool);
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
});
- // ... other registerTool tests
});
describe('getToolsByServer', () => {
it('should return an empty array if no tools match the server name', () => {
- toolRegistry.registerTool(new MockTool()); // A non-MCP tool
+ toolRegistry.registerTool(new MockTool());
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
});
it('should return only tools matching the server name', async () => {
const server1Name = 'mcp-server-uno';
const server2Name = 'mcp-server-dos';
-
- // Manually register mock MCP tools for this test
- const mockCallable = {} as CallableTool; // Minimal mock callable
+ const mockCallable = {} as CallableTool;
const mcpTool1 = new DiscoveredMCPTool(
mockCallable,
server1Name,
@@ -207,73 +209,87 @@ describe('ToolRegistry', () => {
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
expect(toolsFromServer1).toHaveLength(1);
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
- expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe(
- server1Name,
- );
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
expect(toolsFromServer2).toHaveLength(1);
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
- expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe(
- server2Name,
- );
-
- expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]);
});
});
describe('discoverTools', () => {
- let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
- let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>;
- let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>;
- let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>;
-
- beforeEach(() => {
- mockConfigGetToolDiscoveryCommand = vi.spyOn(
- config,
- 'getToolDiscoveryCommand',
- );
- mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers');
- mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand');
- mockExecSync = vi.mocked(execSync);
- toolRegistry = new ToolRegistry(config); // Reset registry
- // Reset the mock for discoverMcpTools before each test in this suite
- mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
- });
-
- it('should discover tools using discovery command', async () => {
- // ... this test remains largely the same
+ it('should sanitize tool parameters during discovery from command', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
- const mockToolDeclarations: FunctionDeclaration[] = [
- {
- name: 'discovered-tool-1',
- description: 'A discovered tool',
- parameters: { type: Type.OBJECT, properties: {} },
+
+ const unsanitizedToolDeclaration: FunctionDeclaration = {
+ name: 'tool-with-bad-format',
+ description: 'A tool with an invalid format property',
+ parameters: {
+ type: Type.OBJECT,
+ properties: {
+ some_string: {
+ type: Type.STRING,
+ format: 'uuid', // This is an unsupported format
+ },
+ },
},
- ];
- mockExecSync.mockReturnValue(
- Buffer.from(
- JSON.stringify([{ function_declarations: mockToolDeclarations }]),
- ),
- );
+ };
+
+ const mockSpawn = vi.mocked(spawn);
+ const mockChildProcess = {
+ stdout: { on: vi.fn() },
+ stderr: { on: vi.fn() },
+ on: vi.fn(),
+ };
+ mockSpawn.mockReturnValue(mockChildProcess as any);
+
+ // Simulate stdout data
+ mockChildProcess.stdout.on.mockImplementation((event, callback) => {
+ if (event === 'data') {
+ callback(
+ Buffer.from(
+ JSON.stringify([
+ { function_declarations: [unsanitizedToolDeclaration] },
+ ]),
+ ),
+ );
+ }
+ return mockChildProcess as any;
+ });
+
+ // Simulate process close
+ mockChildProcess.on.mockImplementation((event, callback) => {
+ if (event === 'close') {
+ callback(0);
+ }
+ return mockChildProcess as any;
+ });
+
await toolRegistry.discoverTools();
- expect(execSync).toHaveBeenCalledWith(discoveryCommand);
- const discoveredTool = toolRegistry.getTool('discovered-tool-1');
- expect(discoveredTool).toBeInstanceOf(DiscoveredTool);
+
+ const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
+ expect(discoveredTool).toBeDefined();
+
+ const registeredParams = (discoveredTool as DiscoveredTool).schema
+ .parameters as Schema;
+ expect(registeredParams.properties?.['some_string']).toBeDefined();
+ expect(registeredParams.properties?.['some_string']).toHaveProperty(
+ 'format',
+ undefined,
+ );
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
- mockConfigGetMcpServerCommand.mockReturnValue(undefined);
+ vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
- } as MCPServerConfig,
+ },
};
- mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal);
+ vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
@@ -282,56 +298,166 @@ describe('ToolRegistry', () => {
undefined,
toolRegistry,
);
- // We no longer check these as discoverMcpTools is mocked
- // expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
- // expect(Client).toHaveBeenCalledTimes(1);
- // expect(StdioClientTransport).toHaveBeenCalledWith({
- // command: 'mcp-server-cmd',
- // args: ['--port', '1234'],
- // env: expect.any(Object),
- // stderr: 'pipe',
- // });
- // expect(mockMcpClientConnect).toHaveBeenCalled();
-
- // To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools
- // to call toolRegistry.registerTool, or we test that separately.
- // For now, we just check that the delegation happened.
});
- it('should discover tools using MCP server command from getMcpServerCommand', async () => {
+ it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
- mockConfigGetMcpServers.mockReturnValue({});
- mockConfigGetMcpServerCommand.mockReturnValue(
- 'mcp-server-start-command --param',
- );
+ vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
+ const mcpServerConfigVal = {
+ 'my-mcp-server': {
+ command: 'mcp-server-cmd',
+ args: ['--port', '1234'],
+ trust: true,
+ },
+ };
+ vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
+
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
- {},
- 'mcp-server-start-command --param',
+ mcpServerConfigVal,
+ undefined,
toolRegistry,
);
});
+ });
+});
- it('should handle errors during MCP client connection gracefully and close transport', async () => {
- mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
- mockConfigGetMcpServers.mockReturnValue({
- 'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig,
- });
+describe('sanitizeParameters', () => {
+ it('should remove unsupported format from a simple string property', () => {
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ name: { type: Type.STRING },
+ id: { type: Type.STRING, format: 'uuid' },
+ },
+ };
+ sanitizeParameters(schema);
+ expect(schema.properties?.['id']).toHaveProperty('format', undefined);
+ expect(schema.properties?.['name']).not.toHaveProperty('format');
+ });
- mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
+ it('should NOT remove supported format values', () => {
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ date: { type: Type.STRING, format: 'date-time' },
+ role: {
+ type: Type.STRING,
+ format: 'enum',
+ enum: ['admin', 'user'],
+ },
+ },
+ };
+ const originalSchema = JSON.parse(JSON.stringify(schema));
+ sanitizeParameters(schema);
+ expect(schema).toEqual(originalSchema);
+ });
- await toolRegistry.discoverTools();
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
- {
- 'failing-mcp': { command: 'fail-cmd' },
+ it('should handle nested objects recursively', () => {
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ user: {
+ type: Type.OBJECT,
+ properties: {
+ email: { type: Type.STRING, format: 'email' },
+ },
},
- undefined,
- toolRegistry,
- );
- expect(toolRegistry.getAllTools()).toHaveLength(0);
- });
+ },
+ };
+ sanitizeParameters(schema);
+ expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
+ 'format',
+ undefined,
+ );
+ });
+
+ it('should handle arrays of objects', () => {
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ items: {
+ type: Type.ARRAY,
+ items: {
+ type: Type.OBJECT,
+ properties: {
+ itemId: { type: Type.STRING, format: 'uuid' },
+ },
+ },
+ },
+ },
+ };
+ sanitizeParameters(schema);
+ expect(
+ (schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
+ ).toHaveProperty('format', undefined);
+ });
+
+ it('should handle schemas with no properties to sanitize', () => {
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ count: { type: Type.NUMBER },
+ isActive: { type: Type.BOOLEAN },
+ },
+ };
+ const originalSchema = JSON.parse(JSON.stringify(schema));
+ sanitizeParameters(schema);
+ expect(schema).toEqual(originalSchema);
+ });
+
+ it('should not crash on an empty or undefined schema', () => {
+ expect(() => sanitizeParameters({})).not.toThrow();
+ expect(() => sanitizeParameters(undefined)).not.toThrow();
+ });
+
+ it('should handle cyclic schemas without crashing', () => {
+ const schema: any = {
+ type: Type.OBJECT,
+ properties: {
+ name: { type: Type.STRING, format: 'hostname' },
+ },
+ };
+ schema.properties.self = schema;
+
+ expect(() => sanitizeParameters(schema)).not.toThrow();
+ expect(schema.properties.name).toHaveProperty('format', undefined);
+ });
+
+ it('should handle complex nested schemas with cycles', () => {
+ const userNode: any = {
+ type: Type.OBJECT,
+ properties: {
+ id: { type: Type.STRING, format: 'uuid' },
+ name: { type: Type.STRING },
+ manager: {
+ type: Type.OBJECT,
+ properties: {
+ id: { type: Type.STRING, format: 'uuid' },
+ },
+ },
+ },
+ };
+ userNode.properties.reports = {
+ type: Type.ARRAY,
+ items: userNode,
+ };
+
+ const schema: Schema = {
+ type: Type.OBJECT,
+ properties: {
+ ceo: userNode,
+ },
+ };
+
+ expect(() => sanitizeParameters(schema)).not.toThrow();
+ expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
+ 'format',
+ undefined,
+ );
+ expect(
+ schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
+ ).toHaveProperty('format', undefined);
});
- // Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed
- // if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts)
});