diff options
Diffstat (limited to 'packages/core/src/tools/memoryTool.ts')
| -rw-r--r-- | packages/core/src/tools/memoryTool.ts | 314 |
1 files changed, 163 insertions, 151 deletions
diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index c8e88c97..a9d765c4 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -5,11 +5,12 @@ */ import { - BaseTool, + BaseDeclarativeTool, + BaseToolInvocation, Kind, - ToolResult, ToolEditConfirmationDetails, ToolConfirmationOutcome, + ToolResult, } from './tools.js'; import { FunctionDeclaration } from '@google/genai'; import * as fs from 'fs/promises'; @@ -19,6 +20,7 @@ import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { tildeifyPath } from '../utils/paths.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; +import { SchemaValidator } from '../utils/schemaValidator.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -110,101 +112,86 @@ function ensureNewlineSeparation(currentContent: string): string { return '\n\n'; } -export class MemoryTool - extends BaseTool<SaveMemoryParams, ToolResult> - implements ModifiableDeclarativeTool<SaveMemoryParams> -{ - private static readonly allowlist: Set<string> = new Set(); - - static readonly Name: string = memoryToolSchemaData.name!; - constructor() { - super( - MemoryTool.Name, - 'Save Memory', - memoryToolDescription, - Kind.Think, - memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>, - ); +/** + * Reads the current content of the memory file + */ +async function readMemoryFileContent(): Promise<string> { + try { + return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + } catch (err) { + const error = err as Error & { code?: string }; + if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; + return ''; } +} - getDescription(_params: SaveMemoryParams): string { - const memoryFilePath = getGlobalMemoryFilePath(); - return `in ${tildeifyPath(memoryFilePath)}`; - } +/** + * Computes the new content that would result from adding a memory entry + */ +function computeNewContent(currentContent: string, fact: string): string { + let processedText = fact.trim(); + processedText = processedText.replace(/^(-+\s*)+/, '').trim(); + const newMemoryItem = `- ${processedText}`; - /** - * Reads the current content of the memory file - */ - private async readMemoryFileContent(): Promise<string> { - try { - return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); - } catch (err) { - const error = err as Error & { code?: string }; - if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; - return ''; - } - } + const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); - /** - * Computes the new content that would result from adding a memory entry - */ - private computeNewContent(currentContent: string, fact: string): string { - let processedText = fact.trim(); - processedText = processedText.replace(/^(-+\s*)+/, '').trim(); - const newMemoryItem = `- ${processedText}`; + if (headerIndex === -1) { + // Header not found, append header and then the entry + const separator = ensureNewlineSeparation(currentContent); + return ( + currentContent + + `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` + ); + } else { + // Header found, find where to insert the new memory entry + const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; + let endOfSectionIndex = currentContent.indexOf( + '\n## ', + startOfSectionContent, + ); + if (endOfSectionIndex === -1) { + endOfSectionIndex = currentContent.length; // End of file + } - const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); + const beforeSectionMarker = currentContent + .substring(0, startOfSectionContent) + .trimEnd(); + let sectionContent = currentContent + .substring(startOfSectionContent, endOfSectionIndex) + .trimEnd(); + const afterSectionMarker = currentContent.substring(endOfSectionIndex); - if (headerIndex === -1) { - // Header not found, append header and then the entry - const separator = ensureNewlineSeparation(currentContent); - return ( - currentContent + - `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` - ); - } else { - // Header found, find where to insert the new memory entry - const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; - let endOfSectionIndex = currentContent.indexOf( - '\n## ', - startOfSectionContent, - ); - if (endOfSectionIndex === -1) { - endOfSectionIndex = currentContent.length; // End of file - } + sectionContent += `\n${newMemoryItem}`; + return ( + `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + + '\n' + ); + } +} - const beforeSectionMarker = currentContent - .substring(0, startOfSectionContent) - .trimEnd(); - let sectionContent = currentContent - .substring(startOfSectionContent, endOfSectionIndex) - .trimEnd(); - const afterSectionMarker = currentContent.substring(endOfSectionIndex); +class MemoryToolInvocation extends BaseToolInvocation< + SaveMemoryParams, + ToolResult +> { + private static readonly allowlist: Set<string> = new Set(); - sectionContent += `\n${newMemoryItem}`; - return ( - `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + - '\n' - ); - } + getDescription(): string { + const memoryFilePath = getGlobalMemoryFilePath(); + return `in ${tildeifyPath(memoryFilePath)}`; } async shouldConfirmExecute( - params: SaveMemoryParams, _abortSignal: AbortSignal, ): Promise<ToolEditConfirmationDetails | false> { const memoryFilePath = getGlobalMemoryFilePath(); const allowlistKey = memoryFilePath; - if (MemoryTool.allowlist.has(allowlistKey)) { + if (MemoryToolInvocation.allowlist.has(allowlistKey)) { return false; } - // Read current content of the memory file - const currentContent = await this.readMemoryFileContent(); - - // Calculate the new content that will be written to the memory file - const newContent = this.computeNewContent(currentContent, params.fact); + const currentContent = await readMemoryFileContent(); + const newContent = computeNewContent(currentContent, this.params.fact); const fileName = path.basename(memoryFilePath); const fileDiff = Diff.createPatch( @@ -226,13 +213,107 @@ export class MemoryTool newContent, onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlways) { - MemoryTool.allowlist.add(allowlistKey); + MemoryToolInvocation.allowlist.add(allowlistKey); } }, }; return confirmationDetails; } + async execute(_signal: AbortSignal): Promise<ToolResult> { + const { fact, modified_by_user, modified_content } = this.params; + + try { + if (modified_by_user && modified_content !== undefined) { + // User modified the content in external editor, write it directly + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile( + getGlobalMemoryFilePath(), + modified_content, + 'utf-8', + ); + const successMessage = `Okay, I've updated the memory file with your modifications.`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } else { + // Use the normal memory entry logic + await MemoryTool.performAddMemoryEntry( + fact, + getGlobalMemoryFilePath(), + { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }, + ); + const successMessage = `Okay, I've remembered that: "${fact}"`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + console.error( + `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`, + ); + return { + llmContent: JSON.stringify({ + success: false, + error: `Failed to save memory. Detail: ${errorMessage}`, + }), + returnDisplay: `Error saving memory: ${errorMessage}`, + }; + } + } +} + +export class MemoryTool + extends BaseDeclarativeTool<SaveMemoryParams, ToolResult> + implements ModifiableDeclarativeTool<SaveMemoryParams> +{ + static readonly Name: string = memoryToolSchemaData.name!; + constructor() { + super( + MemoryTool.Name, + 'Save Memory', + memoryToolDescription, + Kind.Think, + memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>, + ); + } + + validateToolParams(params: SaveMemoryParams): string | null { + const errors = SchemaValidator.validate( + this.schema.parametersJsonSchema, + params, + ); + if (errors) { + return errors; + } + + if (params.fact.trim() === '') { + return 'Parameter "fact" must be a non-empty string.'; + } + + return null; + } + + protected createInvocation(params: SaveMemoryParams) { + return new MemoryToolInvocation(params); + } + static async performAddMemoryEntry( text: string, memoryFilePath: string, @@ -303,83 +384,14 @@ export class MemoryTool } } - async execute( - params: SaveMemoryParams, - _signal: AbortSignal, - ): Promise<ToolResult> { - const { fact, modified_by_user, modified_content } = params; - - if (!fact || typeof fact !== 'string' || fact.trim() === '') { - const errorMessage = 'Parameter "fact" must be a non-empty string.'; - return { - llmContent: JSON.stringify({ success: false, error: errorMessage }), - returnDisplay: `Error: ${errorMessage}`, - }; - } - - try { - if (modified_by_user && modified_content !== undefined) { - // User modified the content in external editor, write it directly - await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { - recursive: true, - }); - await fs.writeFile( - getGlobalMemoryFilePath(), - modified_content, - 'utf-8', - ); - const successMessage = `Okay, I've updated the memory file with your modifications.`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; - } else { - // Use the normal memory entry logic - await MemoryTool.performAddMemoryEntry( - fact, - getGlobalMemoryFilePath(), - { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }, - ); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; - } - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - console.error( - `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`, - ); - return { - llmContent: JSON.stringify({ - success: false, - error: `Failed to save memory. Detail: ${errorMessage}`, - }), - returnDisplay: `Error saving memory: ${errorMessage}`, - }; - } - } - getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> { return { getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), getCurrentContent: async (_params: SaveMemoryParams): Promise<string> => - this.readMemoryFileContent(), + readMemoryFileContent(), getProposedContent: async (params: SaveMemoryParams): Promise<string> => { - const currentContent = await this.readMemoryFileContent(); - return this.computeNewContent(currentContent, params.fact); + const currentContent = await readMemoryFileContent(); + return computeNewContent(currentContent, params.fact); }, createUpdatedParams: ( _oldContent: string, |
