diff options
| author | Olcan <[email protected]> | 2025-07-30 15:21:31 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-30 22:21:31 +0000 |
| commit | ac1bb5ee4275e508dfc2256bbd5ca012e4a4f469 (patch) | |
| tree | 0770166b2ce11d099536b4e2570ecd1830e208e0 /packages/core/src/tools/memoryTool.ts | |
| parent | 498edb57abc9c047e2bd1ea828cc591618745bc4 (diff) | |
confirm save_memory tool, with ability to see diff and edit manually for advanced changes that may override past memories (#5237)
Diffstat (limited to 'packages/core/src/tools/memoryTool.ts')
| -rw-r--r-- | packages/core/src/tools/memoryTool.ts | 198 |
1 files changed, 184 insertions, 14 deletions
diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index f0f1e16b..96509f79 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -4,11 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { BaseTool, Icon, ToolResult } from './tools.js'; +import { + BaseTool, + ToolResult, + ToolEditConfirmationDetails, + ToolConfirmationOutcome, + Icon, +} from './tools.js'; import { FunctionDeclaration, Type } from '@google/genai'; import * as fs from 'fs/promises'; import * as path from 'path'; import { homedir } from 'os'; +import * as Diff from 'diff'; +import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; +import { tildeifyPath } from '../utils/paths.js'; +import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -80,6 +90,8 @@ export function getAllGeminiMdFilenames(): string[] { interface SaveMemoryParams { fact: string; + modified_by_user?: boolean; + modified_content?: string; } function getGlobalMemoryFilePath(): string { @@ -98,7 +110,12 @@ function ensureNewlineSeparation(currentContent: string): string { return '\n\n'; } -export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> { +export class MemoryTool + extends BaseTool<SaveMemoryParams, ToolResult> + implements ModifiableTool<SaveMemoryParams> +{ + private static readonly allowlist: Set<string> = new Set(); + static readonly Name: string = memoryToolSchemaData.name!; constructor() { super( @@ -110,6 +127,111 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> { ); } + getDescription(_params: SaveMemoryParams): string { + const memoryFilePath = getGlobalMemoryFilePath(); + return `in ${tildeifyPath(memoryFilePath)}`; + } + + /** + * Reads the current content of the memory file + */ + private async readMemoryFileContent(): Promise<string> { + try { + return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + } catch (err) { + const error = err as Error & { code?: string }; + if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; + return ''; + } + } + + /** + * Computes the new content that would result from adding a memory entry + */ + private computeNewContent(currentContent: string, fact: string): string { + let processedText = fact.trim(); + processedText = processedText.replace(/^(-+\s*)+/, '').trim(); + const newMemoryItem = `- ${processedText}`; + + const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); + + if (headerIndex === -1) { + // Header not found, append header and then the entry + const separator = ensureNewlineSeparation(currentContent); + return ( + currentContent + + `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` + ); + } else { + // Header found, find where to insert the new memory entry + const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; + let endOfSectionIndex = currentContent.indexOf( + '\n## ', + startOfSectionContent, + ); + if (endOfSectionIndex === -1) { + endOfSectionIndex = currentContent.length; // End of file + } + + const beforeSectionMarker = currentContent + .substring(0, startOfSectionContent) + .trimEnd(); + let sectionContent = currentContent + .substring(startOfSectionContent, endOfSectionIndex) + .trimEnd(); + const afterSectionMarker = currentContent.substring(endOfSectionIndex); + + sectionContent += `\n${newMemoryItem}`; + return ( + `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + + '\n' + ); + } + } + + async shouldConfirmExecute( + params: SaveMemoryParams, + _abortSignal: AbortSignal, + ): Promise<ToolEditConfirmationDetails | false> { + const memoryFilePath = getGlobalMemoryFilePath(); + const allowlistKey = memoryFilePath; + + if (MemoryTool.allowlist.has(allowlistKey)) { + return false; + } + + // Read current content of the memory file + const currentContent = await this.readMemoryFileContent(); + + // Calculate the new content that will be written to the memory file + const newContent = this.computeNewContent(currentContent, params.fact); + + const fileName = path.basename(memoryFilePath); + const fileDiff = Diff.createPatch( + fileName, + currentContent, + newContent, + 'Current', + 'Proposed', + DEFAULT_DIFF_OPTIONS, + ); + + const confirmationDetails: ToolEditConfirmationDetails = { + type: 'edit', + title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`, + fileName: memoryFilePath, + fileDiff, + originalContent: currentContent, + newContent, + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + MemoryTool.allowlist.add(allowlistKey); + } + }, + }; + return confirmationDetails; + } + static async performAddMemoryEntry( text: string, memoryFilePath: string, @@ -184,7 +306,7 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> { params: SaveMemoryParams, _signal: AbortSignal, ): Promise<ToolResult> { - const { fact } = params; + const { fact, modified_by_user, modified_content } = params; if (!fact || typeof fact !== 'string' || fact.trim() === '') { const errorMessage = 'Parameter "fact" must be a non-empty string.'; @@ -195,17 +317,44 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> { } try { - // Use the static method with actual fs promises - await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ success: true, message: successMessage }), - returnDisplay: successMessage, - }; + if (modified_by_user && modified_content !== undefined) { + // User modified the content in external editor, write it directly + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile( + getGlobalMemoryFilePath(), + modified_content, + 'utf-8', + ); + const successMessage = `Okay, I've updated the memory file with your modifications.`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } else { + // Use the normal memory entry logic + await MemoryTool.performAddMemoryEntry( + fact, + getGlobalMemoryFilePath(), + { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }, + ); + const successMessage = `Okay, I've remembered that: "${fact}"`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -221,4 +370,25 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> { }; } } + + getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> { + return { + getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), + getCurrentContent: async (_params: SaveMemoryParams): Promise<string> => + this.readMemoryFileContent(), + getProposedContent: async (params: SaveMemoryParams): Promise<string> => { + const currentContent = await this.readMemoryFileContent(); + return this.computeNewContent(currentContent, params.fact); + }, + createUpdatedParams: ( + _oldContent: string, + modifiedProposedContent: string, + originalParams: SaveMemoryParams, + ): SaveMemoryParams => ({ + ...originalParams, + modified_by_user: true, + modified_content: modifiedProposedContent, + }), + }; + } } |
