From ac1bb5ee4275e508dfc2256bbd5ca012e4a4f469 Mon Sep 17 00:00:00 2001 From: Olcan Date: Wed, 30 Jul 2025 15:21:31 -0700 Subject: confirm save_memory tool, with ability to see diff and edit manually for advanced changes that may override past memories (#5237) --- packages/core/src/tools/memoryTool.ts | 198 +++++++++++++++++++++++++++++++--- 1 file changed, 184 insertions(+), 14 deletions(-) (limited to 'packages/core/src/tools/memoryTool.ts') diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index f0f1e16b..96509f79 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -4,11 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { BaseTool, Icon, ToolResult } from './tools.js'; +import { + BaseTool, + ToolResult, + ToolEditConfirmationDetails, + ToolConfirmationOutcome, + Icon, +} from './tools.js'; import { FunctionDeclaration, Type } from '@google/genai'; import * as fs from 'fs/promises'; import * as path from 'path'; import { homedir } from 'os'; +import * as Diff from 'diff'; +import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; +import { tildeifyPath } from '../utils/paths.js'; +import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -80,6 +90,8 @@ export function getAllGeminiMdFilenames(): string[] { interface SaveMemoryParams { fact: string; + modified_by_user?: boolean; + modified_content?: string; } function getGlobalMemoryFilePath(): string { @@ -98,7 +110,12 @@ function ensureNewlineSeparation(currentContent: string): string { return '\n\n'; } -export class MemoryTool extends BaseTool { +export class MemoryTool + extends BaseTool + implements ModifiableTool +{ + private static readonly allowlist: Set = new Set(); + static readonly Name: string = memoryToolSchemaData.name!; constructor() { super( @@ -110,6 +127,111 @@ export class MemoryTool extends BaseTool { ); } + getDescription(_params: SaveMemoryParams): string { + const memoryFilePath = getGlobalMemoryFilePath(); + return `in ${tildeifyPath(memoryFilePath)}`; + } + + /** + * Reads the current content of the memory file + */ + private async readMemoryFileContent(): Promise { + try { + return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + } catch (err) { + const error = err as Error & { code?: string }; + if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; + return ''; + } + } + + /** + * Computes the new content that would result from adding a memory entry + */ + private computeNewContent(currentContent: string, fact: string): string { + let processedText = fact.trim(); + processedText = processedText.replace(/^(-+\s*)+/, '').trim(); + const newMemoryItem = `- ${processedText}`; + + const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); + + if (headerIndex === -1) { + // Header not found, append header and then the entry + const separator = ensureNewlineSeparation(currentContent); + return ( + currentContent + + `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` + ); + } else { + // Header found, find where to insert the new memory entry + const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; + let endOfSectionIndex = currentContent.indexOf( + '\n## ', + startOfSectionContent, + ); + if (endOfSectionIndex === -1) { + endOfSectionIndex = currentContent.length; // End of file + } + + const beforeSectionMarker = currentContent + .substring(0, startOfSectionContent) + .trimEnd(); + let sectionContent = currentContent + .substring(startOfSectionContent, endOfSectionIndex) + .trimEnd(); + const afterSectionMarker = currentContent.substring(endOfSectionIndex); + + sectionContent += `\n${newMemoryItem}`; + return ( + `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + + '\n' + ); + } + } + + async shouldConfirmExecute( + params: SaveMemoryParams, + _abortSignal: AbortSignal, + ): Promise { + const memoryFilePath = getGlobalMemoryFilePath(); + const allowlistKey = memoryFilePath; + + if (MemoryTool.allowlist.has(allowlistKey)) { + return false; + } + + // Read current content of the memory file + const currentContent = await this.readMemoryFileContent(); + + // Calculate the new content that will be written to the memory file + const newContent = this.computeNewContent(currentContent, params.fact); + + const fileName = path.basename(memoryFilePath); + const fileDiff = Diff.createPatch( + fileName, + currentContent, + newContent, + 'Current', + 'Proposed', + DEFAULT_DIFF_OPTIONS, + ); + + const confirmationDetails: ToolEditConfirmationDetails = { + type: 'edit', + title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`, + fileName: memoryFilePath, + fileDiff, + originalContent: currentContent, + newContent, + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + MemoryTool.allowlist.add(allowlistKey); + } + }, + }; + return confirmationDetails; + } + static async performAddMemoryEntry( text: string, memoryFilePath: string, @@ -184,7 +306,7 @@ export class MemoryTool extends BaseTool { params: SaveMemoryParams, _signal: AbortSignal, ): Promise { - const { fact } = params; + const { fact, modified_by_user, modified_content } = params; if (!fact || typeof fact !== 'string' || fact.trim() === '') { const errorMessage = 'Parameter "fact" must be a non-empty string.'; @@ -195,17 +317,44 @@ export class MemoryTool extends BaseTool { } try { - // Use the static method with actual fs promises - await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ success: true, message: successMessage }), - returnDisplay: successMessage, - }; + if (modified_by_user && modified_content !== undefined) { + // User modified the content in external editor, write it directly + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile( + getGlobalMemoryFilePath(), + modified_content, + 'utf-8', + ); + const successMessage = `Okay, I've updated the memory file with your modifications.`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } else { + // Use the normal memory entry logic + await MemoryTool.performAddMemoryEntry( + fact, + getGlobalMemoryFilePath(), + { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }, + ); + const successMessage = `Okay, I've remembered that: "${fact}"`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -221,4 +370,25 @@ export class MemoryTool extends BaseTool { }; } } + + getModifyContext(_abortSignal: AbortSignal): ModifyContext { + return { + getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), + getCurrentContent: async (_params: SaveMemoryParams): Promise => + this.readMemoryFileContent(), + getProposedContent: async (params: SaveMemoryParams): Promise => { + const currentContent = await this.readMemoryFileContent(); + return this.computeNewContent(currentContent, params.fact); + }, + createUpdatedParams: ( + _oldContent: string, + modifiedProposedContent: string, + originalParams: SaveMemoryParams, + ): SaveMemoryParams => ({ + ...originalParams, + modified_by_user: true, + modified_content: modifiedProposedContent, + }), + }; + } } -- cgit v1.2.3