summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorjoshualitt <[email protected]>2025-08-13 11:57:37 -0700
committerGitHub <[email protected]>2025-08-13 18:57:37 +0000
commit904f4623b6945345d5845649e98f554671b1edfb (patch)
tree57cad495ffe8973af4e02ed2fa2e8cc752905e0a /packages/core/src
parent22109db320e66dcdfa4aff87adaab626b6cf9b15 (diff)
feat(core): Continue declarative tool migration. (#6114)
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/tools/ls.test.ts142
-rw-r--r--packages/core/src/tools/ls.ts213
-rw-r--r--packages/core/src/tools/mcp-tool.test.ts242
-rw-r--r--packages/core/src/tools/mcp-tool.ts115
-rw-r--r--packages/core/src/tools/memoryTool.test.ts76
-rw-r--r--packages/core/src/tools/memoryTool.ts314
6 files changed, 514 insertions, 588 deletions
diff --git a/packages/core/src/tools/ls.test.ts b/packages/core/src/tools/ls.test.ts
index fb99d829..2fbeb37a 100644
--- a/packages/core/src/tools/ls.test.ts
+++ b/packages/core/src/tools/ls.test.ts
@@ -74,9 +74,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/project/src',
};
-
- const error = lsTool.validateToolParams(params);
- expect(error).toBeNull();
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ const invocation = lsTool.build(params);
+ expect(invocation).toBeDefined();
});
it('should reject relative paths', () => {
@@ -84,8 +86,9 @@ describe('LSTool', () => {
path: './src',
};
- const error = lsTool.validateToolParams(params);
- expect(error).toBe('Path must be absolute: ./src');
+ expect(() => lsTool.build(params)).toThrow(
+ 'Path must be absolute: ./src',
+ );
});
it('should reject paths outside workspace with clear error message', () => {
@@ -93,8 +96,7 @@ describe('LSTool', () => {
path: '/etc/passwd',
};
- const error = lsTool.validateToolParams(params);
- expect(error).toBe(
+ expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project',
);
});
@@ -103,9 +105,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/other-project/lib',
};
-
- const error = lsTool.validateToolParams(params);
- expect(error).toBeNull();
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ const invocation = lsTool.build(params);
+ expect(invocation).toBeDefined();
});
});
@@ -133,10 +137,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('[DIR] subdir');
expect(result.llmContent).toContain('file1.ts');
@@ -161,10 +163,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('module1.js');
expect(result.llmContent).toContain('module2.js');
@@ -179,10 +179,8 @@ describe('LSTool', () => {
} as fs.Stats);
vi.mocked(fs.readdirSync).mockReturnValue([]);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toBe(
'Directory /home/user/project/empty is empty.',
@@ -207,10 +205,11 @@ describe('LSTool', () => {
});
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
- const result = await lsTool.execute(
- { path: testPath, ignore: ['*.spec.js'] },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({
+ path: testPath,
+ ignore: ['*.spec.js'],
+ });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test.js');
expect(result.llmContent).toContain('index.js');
@@ -238,10 +237,8 @@ describe('LSTool', () => {
(path: string) => path.includes('ignored.js'),
);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@@ -269,10 +266,8 @@ describe('LSTool', () => {
(path: string) => path.includes('private.js'),
);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@@ -287,10 +282,8 @@ describe('LSTool', () => {
isDirectory: () => false,
} as fs.Stats);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Path is not a directory');
expect(result.returnDisplay).toBe('Error: Path is not a directory.');
@@ -303,10 +296,8 @@ describe('LSTool', () => {
throw new Error('ENOENT: no such file or directory');
});
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
@@ -336,10 +327,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
const lines = (
typeof result.llmContent === 'string' ? result.llmContent : ''
@@ -361,24 +350,18 @@ describe('LSTool', () => {
throw new Error('EACCES: permission denied');
});
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.llmContent).toContain('permission denied');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
});
- it('should validate parameters and return error for invalid params', async () => {
- const result = await lsTool.execute(
- { path: '../outside' },
- new AbortController().signal,
+ it('should throw for invalid params at build time', async () => {
+ expect(() => lsTool.build({ path: '../outside' })).toThrow(
+ 'Path must be absolute: ../outside',
);
-
- expect(result.llmContent).toContain('Invalid parameters provided');
- expect(result.returnDisplay).toBe('Error: Failed to execute tool.');
});
it('should handle errors accessing individual files during listing', async () => {
@@ -406,10 +389,8 @@ describe('LSTool', () => {
.spyOn(console, 'error')
.mockImplementation(() => {});
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
// Should still list the accessible file
expect(result.llmContent).toContain('accessible.ts');
@@ -428,19 +409,25 @@ describe('LSTool', () => {
describe('getDescription', () => {
it('should return shortened relative path', () => {
const params = {
- path: path.join(mockPrimaryDir, 'deeply', 'nested', 'directory'),
+ path: `${mockPrimaryDir}/deeply/nested/directory`,
};
-
- const description = lsTool.getDescription(params);
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ const invocation = lsTool.build(params);
+ const description = invocation.getDescription();
expect(description).toBe(path.join('deeply', 'nested', 'directory'));
});
it('should handle paths in secondary workspace', () => {
const params = {
- path: path.join(mockSecondaryDir, 'lib'),
+ path: `${mockSecondaryDir}/lib`,
};
-
- const description = lsTool.getDescription(params);
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ const invocation = lsTool.build(params);
+ const description = invocation.getDescription();
expect(description).toBe(path.join('..', 'other-project', 'lib'));
});
});
@@ -448,22 +435,25 @@ describe('LSTool', () => {
describe('workspace boundary validation', () => {
it('should accept paths in primary workspace directory', () => {
const params = { path: `${mockPrimaryDir}/src` };
- expect(lsTool.validateToolParams(params)).toBeNull();
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ expect(lsTool.build(params)).toBeDefined();
});
it('should accept paths in secondary workspace directory', () => {
const params = { path: `${mockSecondaryDir}/lib` };
- expect(lsTool.validateToolParams(params)).toBeNull();
+ vi.mocked(fs.statSync).mockReturnValue({
+ isDirectory: () => true,
+ } as fs.Stats);
+ expect(lsTool.build(params)).toBeDefined();
});
it('should reject paths outside all workspace directories', () => {
const params = { path: '/etc/passwd' };
- const error = lsTool.validateToolParams(params);
- expect(error).toContain(
+ expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories',
);
- expect(error).toContain(mockPrimaryDir);
- expect(error).toContain(mockSecondaryDir);
});
it('should list files from secondary workspace directory', async () => {
@@ -483,10 +473,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
- const result = await lsTool.execute(
- { path: testPath },
- new AbortController().signal,
- );
+ const invocation = lsTool.build({ path: testPath });
+ const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test1.spec.ts');
expect(result.llmContent).toContain('test2.spec.ts');
diff --git a/packages/core/src/tools/ls.ts b/packages/core/src/tools/ls.ts
index 7a4445a5..2618136a 100644
--- a/packages/core/src/tools/ls.ts
+++ b/packages/core/src/tools/ls.ts
@@ -6,7 +6,13 @@
import fs from 'fs';
import path from 'path';
-import { BaseTool, Kind, ToolResult } from './tools.js';
+import {
+ BaseDeclarativeTool,
+ BaseToolInvocation,
+ Kind,
+ ToolInvocation,
+ ToolResult,
+} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@@ -64,79 +70,12 @@ export interface FileEntry {
modifiedTime: Date;
}
-/**
- * Implementation of the LS tool logic
- */
-export class LSTool extends BaseTool<LSToolParams, ToolResult> {
- static readonly Name = 'list_directory';
-
- constructor(private config: Config) {
- super(
- LSTool.Name,
- 'ReadFolder',
- 'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
- Kind.Search,
- {
- properties: {
- path: {
- description:
- 'The absolute path to the directory to list (must be absolute, not relative)',
- type: 'string',
- },
- ignore: {
- description: 'List of glob patterns to ignore',
- items: {
- type: 'string',
- },
- type: 'array',
- },
- file_filtering_options: {
- description:
- 'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
- type: 'object',
- properties: {
- respect_git_ignore: {
- description:
- 'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
- type: 'boolean',
- },
- respect_gemini_ignore: {
- description:
- 'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
- type: 'boolean',
- },
- },
- },
- },
- required: ['path'],
- type: 'object',
- },
- );
- }
-
- /**
- * Validates the parameters for the tool
- * @param params Parameters to validate
- * @returns An error message string if invalid, null otherwise
- */
- validateToolParams(params: LSToolParams): string | null {
- const errors = SchemaValidator.validate(
- this.schema.parametersJsonSchema,
- params,
- );
- if (errors) {
- return errors;
- }
- if (!path.isAbsolute(params.path)) {
- return `Path must be absolute: ${params.path}`;
- }
-
- const workspaceContext = this.config.getWorkspaceContext();
- if (!workspaceContext.isPathWithinWorkspace(params.path)) {
- const directories = workspaceContext.getDirectories();
- return `Path must be within one of the workspace directories: ${directories.join(', ')}`;
- }
- return null;
+class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
+ constructor(
+ private readonly config: Config,
+ params: LSToolParams,
+ ) {
+ super(params);
}
/**
@@ -165,11 +104,13 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
/**
* Gets a description of the file reading operation
- * @param params Parameters for the file reading
* @returns A string describing the file being read
*/
- getDescription(params: LSToolParams): string {
- const relativePath = makeRelative(params.path, this.config.getTargetDir());
+ getDescription(): string {
+ const relativePath = makeRelative(
+ this.params.path,
+ this.config.getTargetDir(),
+ );
return shortenPath(relativePath);
}
@@ -184,49 +125,37 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
/**
* Executes the LS operation with the given parameters
- * @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
- async execute(
- params: LSToolParams,
- _signal: AbortSignal,
- ): Promise<ToolResult> {
- const validationError = this.validateToolParams(params);
- if (validationError) {
- return this.errorResult(
- `Error: Invalid parameters provided. Reason: ${validationError}`,
- `Failed to execute tool.`,
- );
- }
-
+ async execute(_signal: AbortSignal): Promise<ToolResult> {
try {
- const stats = fs.statSync(params.path);
+ const stats = fs.statSync(this.params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
- `Error: Directory not found or inaccessible: ${params.path}`,
+ `Error: Directory not found or inaccessible: ${this.params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
- `Error: Path is not a directory: ${params.path}`,
+ `Error: Path is not a directory: ${this.params.path}`,
`Path is not a directory.`,
);
}
- const files = fs.readdirSync(params.path);
+ const files = fs.readdirSync(this.params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
- params.file_filtering_options?.respect_git_ignore ??
+ this.params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
- params.file_filtering_options?.respect_gemini_ignore ??
+ this.params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
@@ -241,17 +170,17 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
- llmContent: `Directory ${params.path} is empty.`,
+ llmContent: `Directory ${this.params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
- if (this.shouldIgnore(file, params.ignore)) {
+ if (this.shouldIgnore(file, this.params.ignore)) {
continue;
}
- const fullPath = path.join(params.path, file);
+ const fullPath = path.join(this.params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
@@ -301,7 +230,7 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
- let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
+ let resultMessage = `Directory listing for ${this.params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
@@ -329,3 +258,87 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
}
}
}
+
+/**
+ * Implementation of the LS tool logic
+ */
+export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
+ static readonly Name = 'list_directory';
+
+ constructor(private config: Config) {
+ super(
+ LSTool.Name,
+ 'ReadFolder',
+ 'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
+ Kind.Search,
+ {
+ properties: {
+ path: {
+ description:
+ 'The absolute path to the directory to list (must be absolute, not relative)',
+ type: 'string',
+ },
+ ignore: {
+ description: 'List of glob patterns to ignore',
+ items: {
+ type: 'string',
+ },
+ type: 'array',
+ },
+ file_filtering_options: {
+ description:
+ 'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
+ type: 'object',
+ properties: {
+ respect_git_ignore: {
+ description:
+ 'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
+ type: 'boolean',
+ },
+ respect_gemini_ignore: {
+ description:
+ 'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
+ type: 'boolean',
+ },
+ },
+ },
+ },
+ required: ['path'],
+ type: 'object',
+ },
+ );
+ }
+
+ /**
+ * Validates the parameters for the tool
+ * @param params Parameters to validate
+ * @returns An error message string if invalid, null otherwise
+ */
+ validateToolParams(params: LSToolParams): string | null {
+ const errors = SchemaValidator.validate(
+ this.schema.parametersJsonSchema,
+ params,
+ );
+ if (errors) {
+ return errors;
+ }
+ if (!path.isAbsolute(params.path)) {
+ return `Path must be absolute: ${params.path}`;
+ }
+
+ const workspaceContext = this.config.getWorkspaceContext();
+ if (!workspaceContext.isPathWithinWorkspace(params.path)) {
+ const directories = workspaceContext.getDirectories();
+ return `Path must be within one of the workspace directories: ${directories.join(
+ ', ',
+ )}`;
+ }
+ return null;
+ }
+
+ protected createInvocation(
+ params: LSToolParams,
+ ): ToolInvocation<LSToolParams, ToolResult> {
+ return new LSToolInvocation(this.config, params);
+ }
+}
diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts
index f8a9a8ba..36602d49 100644
--- a/packages/core/src/tools/mcp-tool.test.ts
+++ b/packages/core/src/tools/mcp-tool.test.ts
@@ -73,11 +73,21 @@ describe('DiscoveredMCPTool', () => {
required: ['param'],
};
+ let tool: DiscoveredMCPTool;
+
beforeEach(() => {
mockCallTool.mockClear();
mockToolMethod.mockClear();
+ tool = new DiscoveredMCPTool(
+ mockCallableToolInstance,
+ serverName,
+ serverToolName,
+ baseDescription,
+ inputSchema,
+ );
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
- (DiscoveredMCPTool as any).allowlist.clear();
+ const invocation = tool.build({}) as any;
+ invocation.constructor.allowlist.clear();
});
afterEach(() => {
@@ -86,14 +96,6 @@ describe('DiscoveredMCPTool', () => {
describe('constructor', () => {
it('should set properties correctly', () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
-
expect(tool.name).toBe(serverToolName);
expect(tool.schema.name).toBe(serverToolName);
expect(tool.schema.description).toBe(baseDescription);
@@ -105,7 +107,7 @@ describe('DiscoveredMCPTool', () => {
it('should accept and store a custom timeout', () => {
const customTimeout = 5000;
- const tool = new DiscoveredMCPTool(
+ const toolWithTimeout = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@@ -113,19 +115,12 @@ describe('DiscoveredMCPTool', () => {
inputSchema,
customTimeout,
);
- expect(tool.timeout).toBe(customTimeout);
+ expect(toolWithTimeout.timeout).toBe(customTimeout);
});
});
describe('execute', () => {
it('should call mcpTool.callTool with correct parameters and format display output', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { param: 'testValue' };
const mockToolSuccessResultObject = {
success: true,
@@ -147,7 +142,10 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
- const toolResult: ToolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult: ToolResult = await invocation.execute(
+ new AbortController().signal,
+ );
expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params },
@@ -163,17 +161,13 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle empty result from getStringifiedResultForDisplay', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { param: 'testValue' };
const mockMcpToolResponsePartsEmpty: Part[] = [];
mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
- const toolResult: ToolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult: ToolResult = await invocation.execute(
+ new AbortController().signal,
+ );
expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
expect(toolResult.llmContent).toEqual([
{ text: '[Error: Could not parse tool response]' },
@@ -181,28 +175,17 @@ describe('DiscoveredMCPTool', () => {
});
it('should propagate rejection if mcpTool.callTool rejects', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { param: 'failCase' };
const expectedError = new Error('MCP call failed');
mockCallTool.mockRejectedValue(expectedError);
- await expect(tool.execute(params)).rejects.toThrow(expectedError);
+ const invocation = tool.build(params);
+ await expect(
+ invocation.execute(new AbortController().signal),
+ ).rejects.toThrow(expectedError);
});
it('should handle a simple text response correctly', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { query: 'test' };
const successMessage = 'This is a success message.';
@@ -221,7 +204,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
// 1. Assert that the llmContent sent to the scheduler is a clean Part array.
expect(toolResult.llmContent).toEqual([{ text: successMessage }]);
@@ -236,13 +220,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an AudioBlock response', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { action: 'play' };
const sdkResponse: Part[] = [
{
@@ -262,7 +239,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -279,13 +257,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a ResourceLinkBlock response', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -306,7 +277,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -319,13 +291,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded text ResourceBlock response', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -348,7 +313,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'This is the text content.' },
@@ -357,13 +323,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded binary ResourceBlock response', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -386,7 +345,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -405,13 +365,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a mix of content block types', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { action: 'complex' };
const sdkResponse: Part[] = [
{
@@ -433,7 +386,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'First part.' },
@@ -454,13 +408,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should ignore unknown content block types', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { action: 'test' };
const sdkResponse: Part[] = [
{
@@ -477,7 +424,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]);
expect(toolResult.returnDisplay).toBe(
@@ -486,13 +434,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a complex mix of content block types', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const params = { action: 'super-complex' };
const sdkResponse: Part[] = [
{
@@ -527,7 +468,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
- const toolResult = await tool.execute(params);
+ const invocation = tool.build(params);
+ const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'Here is a resource.' },
@@ -552,10 +494,8 @@ describe('DiscoveredMCPTool', () => {
});
describe('shouldConfirmExecute', () => {
- // beforeEach is already clearing allowlist
-
it('should return false if trust is true', async () => {
- const tool = new DiscoveredMCPTool(
+ const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@@ -564,50 +504,32 @@ describe('DiscoveredMCPTool', () => {
undefined,
true,
);
+ const invocation = trustedTool.build({});
expect(
- await tool.shouldConfirmExecute({}, new AbortController().signal),
+ await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if server is allowlisted', async () => {
- (DiscoveredMCPTool as any).allowlist.add(serverName);
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
+ const invocation = tool.build({}) as any;
+ invocation.constructor.allowlist.add(serverName);
expect(
- await tool.shouldConfirmExecute({}, new AbortController().signal),
+ await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if tool is allowlisted', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`;
- (DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey);
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
+ const invocation = tool.build({}) as any;
+ invocation.constructor.allowlist.add(toolAllowlistKey);
expect(
- await tool.shouldConfirmExecute({}, new AbortController().signal),
+ await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return confirmation details if not trusted and not allowlisted', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
- const confirmation = await tool.shouldConfirmExecute(
- {},
+ const invocation = tool.build({});
+ const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -629,15 +551,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should add server to allowlist on ProceedAlwaysServer', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
- const confirmation = await tool.shouldConfirmExecute(
- {},
+ const invocation = tool.build({}) as any;
+ const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -650,7 +565,7 @@ describe('DiscoveredMCPTool', () => {
await confirmation.onConfirm(
ToolConfirmationOutcome.ProceedAlwaysServer,
);
- expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true);
+ expect(invocation.constructor.allowlist.has(serverName)).toBe(true);
} else {
throw new Error(
'Confirmation details or onConfirm not in expected format',
@@ -659,16 +574,9 @@ describe('DiscoveredMCPTool', () => {
});
it('should add tool to allowlist on ProceedAlwaysTool', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
const toolAllowlistKey = `${serverName}.${serverToolName}`;
- const confirmation = await tool.shouldConfirmExecute(
- {},
+ const invocation = tool.build({}) as any;
+ const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -679,7 +587,7 @@ describe('DiscoveredMCPTool', () => {
typeof confirmation.onConfirm === 'function'
) {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool);
- expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe(
+ expect(invocation.constructor.allowlist.has(toolAllowlistKey)).toBe(
true,
);
} else {
@@ -690,15 +598,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle Cancel confirmation outcome', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
- const confirmation = await tool.shouldConfirmExecute(
- {},
+ const invocation = tool.build({}) as any;
+ const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -710,11 +611,9 @@ describe('DiscoveredMCPTool', () => {
) {
// Cancel should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
- expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
- false,
- );
+ expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
- (DiscoveredMCPTool as any).allowlist.has(
+ invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
@@ -726,15 +625,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle ProceedOnce confirmation outcome', async () => {
- const tool = new DiscoveredMCPTool(
- mockCallableToolInstance,
- serverName,
- serverToolName,
- baseDescription,
- inputSchema,
- );
- const confirmation = await tool.shouldConfirmExecute(
- {},
+ const invocation = tool.build({}) as any;
+ const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -746,11 +638,9 @@ describe('DiscoveredMCPTool', () => {
) {
// ProceedOnce should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
- expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
- false,
- );
+ expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
- (DiscoveredMCPTool as any).allowlist.has(
+ invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts
index 59f83db3..01a8d75c 100644
--- a/packages/core/src/tools/mcp-tool.ts
+++ b/packages/core/src/tools/mcp-tool.ts
@@ -5,14 +5,16 @@
*/
import {
- BaseTool,
- ToolResult,
+ BaseDeclarativeTool,
+ BaseToolInvocation,
+ Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
+ ToolInvocation,
ToolMcpConfirmationDetails,
- Kind,
+ ToolResult,
} from './tools.js';
-import { CallableTool, Part, FunctionCall } from '@google/genai';
+import { CallableTool, FunctionCall, Part } from '@google/genai';
type ToolParams = Record<string, unknown>;
@@ -50,45 +52,25 @@ type McpContentBlock =
| McpResourceBlock
| McpResourceLinkBlock;
-export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
+class DiscoveredMCPToolInvocation extends BaseToolInvocation<
+ ToolParams,
+ ToolResult
+> {
private static readonly allowlist: Set<string> = new Set();
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly serverToolName: string,
- description: string,
- readonly parameterSchema: unknown,
+ readonly displayName: string,
readonly timeout?: number,
readonly trust?: boolean,
- nameOverride?: string,
+ params: ToolParams = {},
) {
- super(
- nameOverride ?? generateValidName(serverToolName),
- `${serverToolName} (${serverName} MCP Server)`,
- description,
- Kind.Other,
- parameterSchema,
- true, // isOutputMarkdown
- false, // canUpdateOutput
- );
- }
-
- asFullyQualifiedTool(): DiscoveredMCPTool {
- return new DiscoveredMCPTool(
- this.mcpTool,
- this.serverName,
- this.serverToolName,
- this.description,
- this.parameterSchema,
- this.timeout,
- this.trust,
- `${this.serverName}__${this.serverToolName}`,
- );
+ super(params);
}
async shouldConfirmExecute(
- _params: ToolParams,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
@@ -99,8 +81,8 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
}
if (
- DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
- DiscoveredMCPTool.allowlist.has(toolAllowListKey)
+ DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) ||
+ DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
@@ -110,23 +92,23 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
- toolDisplayName: this.name, // Display global registry name exposed to model and user
+ toolDisplayName: this.displayName, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
- DiscoveredMCPTool.allowlist.add(serverAllowListKey);
+ DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
- DiscoveredMCPTool.allowlist.add(toolAllowListKey);
+ DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
- async execute(params: ToolParams): Promise<ToolResult> {
+ async execute(): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
- args: params,
+ args: this.params,
},
];
@@ -138,6 +120,63 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
}
+
+ getDescription(): string {
+ return this.displayName;
+ }
+}
+
+export class DiscoveredMCPTool extends BaseDeclarativeTool<
+ ToolParams,
+ ToolResult
+> {
+ constructor(
+ private readonly mcpTool: CallableTool,
+ readonly serverName: string,
+ readonly serverToolName: string,
+ description: string,
+ readonly parameterSchema: unknown,
+ readonly timeout?: number,
+ readonly trust?: boolean,
+ nameOverride?: string,
+ ) {
+ super(
+ nameOverride ?? generateValidName(serverToolName),
+ `${serverToolName} (${serverName} MCP Server)`,
+ description,
+ Kind.Other,
+ parameterSchema,
+ true, // isOutputMarkdown
+ false, // canUpdateOutput
+ );
+ }
+
+ asFullyQualifiedTool(): DiscoveredMCPTool {
+ return new DiscoveredMCPTool(
+ this.mcpTool,
+ this.serverName,
+ this.serverToolName,
+ this.description,
+ this.parameterSchema,
+ this.timeout,
+ this.trust,
+ `${this.serverName}__${this.serverToolName}`,
+ );
+ }
+
+ protected createInvocation(
+ params: ToolParams,
+ ): ToolInvocation<ToolParams, ToolResult> {
+ return new DiscoveredMCPToolInvocation(
+ this.mcpTool,
+ this.serverName,
+ this.serverToolName,
+ this.displayName,
+ this.timeout,
+ this.trust,
+ params,
+ );
+ }
}
function transformTextBlock(block: McpTextBlock): Part {
diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts
index 2a5c4c39..0e382325 100644
--- a/packages/core/src/tools/memoryTool.test.ts
+++ b/packages/core/src/tools/memoryTool.test.ts
@@ -218,7 +218,8 @@ describe('MemoryTool', () => {
it('should call performAddMemoryEntry with correct parameters and return success', async () => {
const params = { fact: 'The sky is blue' };
- const result = await memoryTool.execute(params, mockAbortSignal);
+ const invocation = memoryTool.build(params);
+ const result = await invocation.execute(mockAbortSignal);
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join(
os.homedir(),
@@ -247,14 +248,12 @@ describe('MemoryTool', () => {
it('should return an error if fact is empty', async () => {
const params = { fact: ' ' }; // Empty fact
- const result = await memoryTool.execute(params, mockAbortSignal);
- const errorMessage = 'Parameter "fact" must be a non-empty string.';
-
- expect(performAddMemoryEntrySpy).not.toHaveBeenCalled();
- expect(result.llmContent).toBe(
- JSON.stringify({ success: false, error: errorMessage }),
+ expect(memoryTool.validateToolParams(params)).toBe(
+ 'Parameter "fact" must be a non-empty string.',
+ );
+ expect(() => memoryTool.build(params)).toThrow(
+ 'Parameter "fact" must be a non-empty string.',
);
- expect(result.returnDisplay).toBe(`Error: ${errorMessage}`);
});
it('should handle errors from performAddMemoryEntry', async () => {
@@ -264,7 +263,8 @@ describe('MemoryTool', () => {
);
performAddMemoryEntrySpy.mockRejectedValue(underlyingError);
- const result = await memoryTool.execute(params, mockAbortSignal);
+ const invocation = memoryTool.build(params);
+ const result = await invocation.execute(mockAbortSignal);
expect(result.llmContent).toBe(
JSON.stringify({
@@ -284,17 +284,17 @@ describe('MemoryTool', () => {
beforeEach(() => {
memoryTool = new MemoryTool();
// Clear the allowlist before each test
- (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.clear();
+ const invocation = memoryTool.build({ fact: 'mock-fact' });
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ (invocation.constructor as any).allowlist.clear();
// Mock fs.readFile to return empty string (file doesn't exist)
vi.mocked(fs.readFile).mockResolvedValue('');
});
it('should return confirmation details when memory file is not allowlisted', async () => {
const params = { fact: 'Test fact' };
- const result = await memoryTool.shouldConfirmExecute(
- params,
- mockAbortSignal,
- );
+ const invocation = memoryTool.build(params);
+ const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -321,15 +321,12 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
+ const invocation = memoryTool.build(params);
// Add the memory file to the allowlist
- (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.add(
- memoryFilePath,
- );
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ (invocation.constructor as any).allowlist.add(memoryFilePath);
- const result = await memoryTool.shouldConfirmExecute(
- params,
- mockAbortSignal,
- );
+ const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBe(false);
});
@@ -342,10 +339,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
- const result = await memoryTool.shouldConfirmExecute(
- params,
- mockAbortSignal,
- );
+ const invocation = memoryTool.build(params);
+ const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -356,9 +351,8 @@ describe('MemoryTool', () => {
// Check that the memory file was added to the allowlist
expect(
- (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
- memoryFilePath,
- ),
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ (invocation.constructor as any).allowlist.has(memoryFilePath),
).toBe(true);
}
});
@@ -371,10 +365,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
- const result = await memoryTool.shouldConfirmExecute(
- params,
- mockAbortSignal,
- );
+ const invocation = memoryTool.build(params);
+ const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -382,18 +374,12 @@ describe('MemoryTool', () => {
if (result && result.type === 'edit') {
// Simulate the onConfirm callback with different outcomes
await result.onConfirm(ToolConfirmationOutcome.ProceedOnce);
- expect(
- (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
- memoryFilePath,
- ),
- ).toBe(false);
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ const allowlist = (invocation.constructor as any).allowlist;
+ expect(allowlist.has(memoryFilePath)).toBe(false);
await result.onConfirm(ToolConfirmationOutcome.Cancel);
- expect(
- (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
- memoryFilePath,
- ),
- ).toBe(false);
+ expect(allowlist.has(memoryFilePath)).toBe(false);
}
});
@@ -405,10 +391,8 @@ describe('MemoryTool', () => {
// Mock fs.readFile to return existing content
vi.mocked(fs.readFile).mockResolvedValue(existingContent);
- const result = await memoryTool.shouldConfirmExecute(
- params,
- mockAbortSignal,
- );
+ const invocation = memoryTool.build(params);
+ const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts
index c8e88c97..a9d765c4 100644
--- a/packages/core/src/tools/memoryTool.ts
+++ b/packages/core/src/tools/memoryTool.ts
@@ -5,11 +5,12 @@
*/
import {
- BaseTool,
+ BaseDeclarativeTool,
+ BaseToolInvocation,
Kind,
- ToolResult,
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
+ ToolResult,
} from './tools.js';
import { FunctionDeclaration } from '@google/genai';
import * as fs from 'fs/promises';
@@ -19,6 +20,7 @@ import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
+import { SchemaValidator } from '../utils/schemaValidator.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@@ -110,101 +112,86 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n';
}
-export class MemoryTool
- extends BaseTool<SaveMemoryParams, ToolResult>
- implements ModifiableDeclarativeTool<SaveMemoryParams>
-{
- private static readonly allowlist: Set<string> = new Set();
-
- static readonly Name: string = memoryToolSchemaData.name!;
- constructor() {
- super(
- MemoryTool.Name,
- 'Save Memory',
- memoryToolDescription,
- Kind.Think,
- memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
- );
+/**
+ * Reads the current content of the memory file
+ */
+async function readMemoryFileContent(): Promise<string> {
+ try {
+ return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
+ } catch (err) {
+ const error = err as Error & { code?: string };
+ if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
+ return '';
}
+}
- getDescription(_params: SaveMemoryParams): string {
- const memoryFilePath = getGlobalMemoryFilePath();
- return `in ${tildeifyPath(memoryFilePath)}`;
- }
+/**
+ * Computes the new content that would result from adding a memory entry
+ */
+function computeNewContent(currentContent: string, fact: string): string {
+ let processedText = fact.trim();
+ processedText = processedText.replace(/^(-+\s*)+/, '').trim();
+ const newMemoryItem = `- ${processedText}`;
- /**
- * Reads the current content of the memory file
- */
- private async readMemoryFileContent(): Promise<string> {
- try {
- return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
- } catch (err) {
- const error = err as Error & { code?: string };
- if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
- return '';
- }
- }
+ const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
- /**
- * Computes the new content that would result from adding a memory entry
- */
- private computeNewContent(currentContent: string, fact: string): string {
- let processedText = fact.trim();
- processedText = processedText.replace(/^(-+\s*)+/, '').trim();
- const newMemoryItem = `- ${processedText}`;
+ if (headerIndex === -1) {
+ // Header not found, append header and then the entry
+ const separator = ensureNewlineSeparation(currentContent);
+ return (
+ currentContent +
+ `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
+ );
+ } else {
+ // Header found, find where to insert the new memory entry
+ const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
+ let endOfSectionIndex = currentContent.indexOf(
+ '\n## ',
+ startOfSectionContent,
+ );
+ if (endOfSectionIndex === -1) {
+ endOfSectionIndex = currentContent.length; // End of file
+ }
- const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
+ const beforeSectionMarker = currentContent
+ .substring(0, startOfSectionContent)
+ .trimEnd();
+ let sectionContent = currentContent
+ .substring(startOfSectionContent, endOfSectionIndex)
+ .trimEnd();
+ const afterSectionMarker = currentContent.substring(endOfSectionIndex);
- if (headerIndex === -1) {
- // Header not found, append header and then the entry
- const separator = ensureNewlineSeparation(currentContent);
- return (
- currentContent +
- `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
- );
- } else {
- // Header found, find where to insert the new memory entry
- const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
- let endOfSectionIndex = currentContent.indexOf(
- '\n## ',
- startOfSectionContent,
- );
- if (endOfSectionIndex === -1) {
- endOfSectionIndex = currentContent.length; // End of file
- }
+ sectionContent += `\n${newMemoryItem}`;
+ return (
+ `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
+ '\n'
+ );
+ }
+}
- const beforeSectionMarker = currentContent
- .substring(0, startOfSectionContent)
- .trimEnd();
- let sectionContent = currentContent
- .substring(startOfSectionContent, endOfSectionIndex)
- .trimEnd();
- const afterSectionMarker = currentContent.substring(endOfSectionIndex);
+class MemoryToolInvocation extends BaseToolInvocation<
+ SaveMemoryParams,
+ ToolResult
+> {
+ private static readonly allowlist: Set<string> = new Set();
- sectionContent += `\n${newMemoryItem}`;
- return (
- `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
- '\n'
- );
- }
+ getDescription(): string {
+ const memoryFilePath = getGlobalMemoryFilePath();
+ return `in ${tildeifyPath(memoryFilePath)}`;
}
async shouldConfirmExecute(
- params: SaveMemoryParams,
_abortSignal: AbortSignal,
): Promise<ToolEditConfirmationDetails | false> {
const memoryFilePath = getGlobalMemoryFilePath();
const allowlistKey = memoryFilePath;
- if (MemoryTool.allowlist.has(allowlistKey)) {
+ if (MemoryToolInvocation.allowlist.has(allowlistKey)) {
return false;
}
- // Read current content of the memory file
- const currentContent = await this.readMemoryFileContent();
-
- // Calculate the new content that will be written to the memory file
- const newContent = this.computeNewContent(currentContent, params.fact);
+ const currentContent = await readMemoryFileContent();
+ const newContent = computeNewContent(currentContent, this.params.fact);
const fileName = path.basename(memoryFilePath);
const fileDiff = Diff.createPatch(
@@ -226,13 +213,107 @@ export class MemoryTool
newContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
- MemoryTool.allowlist.add(allowlistKey);
+ MemoryToolInvocation.allowlist.add(allowlistKey);
}
},
};
return confirmationDetails;
}
+ async execute(_signal: AbortSignal): Promise<ToolResult> {
+ const { fact, modified_by_user, modified_content } = this.params;
+
+ try {
+ if (modified_by_user && modified_content !== undefined) {
+ // User modified the content in external editor, write it directly
+ await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
+ recursive: true,
+ });
+ await fs.writeFile(
+ getGlobalMemoryFilePath(),
+ modified_content,
+ 'utf-8',
+ );
+ const successMessage = `Okay, I've updated the memory file with your modifications.`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ } else {
+ // Use the normal memory entry logic
+ await MemoryTool.performAddMemoryEntry(
+ fact,
+ getGlobalMemoryFilePath(),
+ {
+ readFile: fs.readFile,
+ writeFile: fs.writeFile,
+ mkdir: fs.mkdir,
+ },
+ );
+ const successMessage = `Okay, I've remembered that: "${fact}"`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ }
+ } catch (error) {
+ const errorMessage =
+ error instanceof Error ? error.message : String(error);
+ console.error(
+ `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
+ );
+ return {
+ llmContent: JSON.stringify({
+ success: false,
+ error: `Failed to save memory. Detail: ${errorMessage}`,
+ }),
+ returnDisplay: `Error saving memory: ${errorMessage}`,
+ };
+ }
+ }
+}
+
+export class MemoryTool
+ extends BaseDeclarativeTool<SaveMemoryParams, ToolResult>
+ implements ModifiableDeclarativeTool<SaveMemoryParams>
+{
+ static readonly Name: string = memoryToolSchemaData.name!;
+ constructor() {
+ super(
+ MemoryTool.Name,
+ 'Save Memory',
+ memoryToolDescription,
+ Kind.Think,
+ memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
+ );
+ }
+
+ validateToolParams(params: SaveMemoryParams): string | null {
+ const errors = SchemaValidator.validate(
+ this.schema.parametersJsonSchema,
+ params,
+ );
+ if (errors) {
+ return errors;
+ }
+
+ if (params.fact.trim() === '') {
+ return 'Parameter "fact" must be a non-empty string.';
+ }
+
+ return null;
+ }
+
+ protected createInvocation(params: SaveMemoryParams) {
+ return new MemoryToolInvocation(params);
+ }
+
static async performAddMemoryEntry(
text: string,
memoryFilePath: string,
@@ -303,83 +384,14 @@ export class MemoryTool
}
}
- async execute(
- params: SaveMemoryParams,
- _signal: AbortSignal,
- ): Promise<ToolResult> {
- const { fact, modified_by_user, modified_content } = params;
-
- if (!fact || typeof fact !== 'string' || fact.trim() === '') {
- const errorMessage = 'Parameter "fact" must be a non-empty string.';
- return {
- llmContent: JSON.stringify({ success: false, error: errorMessage }),
- returnDisplay: `Error: ${errorMessage}`,
- };
- }
-
- try {
- if (modified_by_user && modified_content !== undefined) {
- // User modified the content in external editor, write it directly
- await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
- recursive: true,
- });
- await fs.writeFile(
- getGlobalMemoryFilePath(),
- modified_content,
- 'utf-8',
- );
- const successMessage = `Okay, I've updated the memory file with your modifications.`;
- return {
- llmContent: JSON.stringify({
- success: true,
- message: successMessage,
- }),
- returnDisplay: successMessage,
- };
- } else {
- // Use the normal memory entry logic
- await MemoryTool.performAddMemoryEntry(
- fact,
- getGlobalMemoryFilePath(),
- {
- readFile: fs.readFile,
- writeFile: fs.writeFile,
- mkdir: fs.mkdir,
- },
- );
- const successMessage = `Okay, I've remembered that: "${fact}"`;
- return {
- llmContent: JSON.stringify({
- success: true,
- message: successMessage,
- }),
- returnDisplay: successMessage,
- };
- }
- } catch (error) {
- const errorMessage =
- error instanceof Error ? error.message : String(error);
- console.error(
- `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
- );
- return {
- llmContent: JSON.stringify({
- success: false,
- error: `Failed to save memory. Detail: ${errorMessage}`,
- }),
- returnDisplay: `Error saving memory: ${errorMessage}`,
- };
- }
- }
-
getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
return {
getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(),
getCurrentContent: async (_params: SaveMemoryParams): Promise<string> =>
- this.readMemoryFileContent(),
+ readMemoryFileContent(),
getProposedContent: async (params: SaveMemoryParams): Promise<string> => {
- const currentContent = await this.readMemoryFileContent();
- return this.computeNewContent(currentContent, params.fact);
+ const currentContent = await readMemoryFileContent();
+ return computeNewContent(currentContent, params.fact);
},
createUpdatedParams: (
_oldContent: string,