summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/memoryTool.ts
diff options
context:
space:
mode:
authorOlcan <[email protected]>2025-07-30 15:21:31 -0700
committerGitHub <[email protected]>2025-07-30 22:21:31 +0000
commitac1bb5ee4275e508dfc2256bbd5ca012e4a4f469 (patch)
tree0770166b2ce11d099536b4e2570ecd1830e208e0 /packages/core/src/tools/memoryTool.ts
parent498edb57abc9c047e2bd1ea828cc591618745bc4 (diff)
confirm save_memory tool, with ability to see diff and edit manually for advanced changes that may override past memories (#5237)
Diffstat (limited to 'packages/core/src/tools/memoryTool.ts')
-rw-r--r--packages/core/src/tools/memoryTool.ts198
1 files changed, 184 insertions, 14 deletions
diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts
index f0f1e16b..96509f79 100644
--- a/packages/core/src/tools/memoryTool.ts
+++ b/packages/core/src/tools/memoryTool.ts
@@ -4,11 +4,21 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { BaseTool, Icon, ToolResult } from './tools.js';
+import {
+ BaseTool,
+ ToolResult,
+ ToolEditConfirmationDetails,
+ ToolConfirmationOutcome,
+ Icon,
+} from './tools.js';
import { FunctionDeclaration, Type } from '@google/genai';
import * as fs from 'fs/promises';
import * as path from 'path';
import { homedir } from 'os';
+import * as Diff from 'diff';
+import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
+import { tildeifyPath } from '../utils/paths.js';
+import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@@ -80,6 +90,8 @@ export function getAllGeminiMdFilenames(): string[] {
interface SaveMemoryParams {
fact: string;
+ modified_by_user?: boolean;
+ modified_content?: string;
}
function getGlobalMemoryFilePath(): string {
@@ -98,7 +110,12 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n';
}
-export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
+export class MemoryTool
+ extends BaseTool<SaveMemoryParams, ToolResult>
+ implements ModifiableTool<SaveMemoryParams>
+{
+ private static readonly allowlist: Set<string> = new Set();
+
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
@@ -110,6 +127,111 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
);
}
+ getDescription(_params: SaveMemoryParams): string {
+ const memoryFilePath = getGlobalMemoryFilePath();
+ return `in ${tildeifyPath(memoryFilePath)}`;
+ }
+
+ /**
+ * Reads the current content of the memory file
+ */
+ private async readMemoryFileContent(): Promise<string> {
+ try {
+ return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
+ } catch (err) {
+ const error = err as Error & { code?: string };
+ if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
+ return '';
+ }
+ }
+
+ /**
+ * Computes the new content that would result from adding a memory entry
+ */
+ private computeNewContent(currentContent: string, fact: string): string {
+ let processedText = fact.trim();
+ processedText = processedText.replace(/^(-+\s*)+/, '').trim();
+ const newMemoryItem = `- ${processedText}`;
+
+ const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
+
+ if (headerIndex === -1) {
+ // Header not found, append header and then the entry
+ const separator = ensureNewlineSeparation(currentContent);
+ return (
+ currentContent +
+ `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
+ );
+ } else {
+ // Header found, find where to insert the new memory entry
+ const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
+ let endOfSectionIndex = currentContent.indexOf(
+ '\n## ',
+ startOfSectionContent,
+ );
+ if (endOfSectionIndex === -1) {
+ endOfSectionIndex = currentContent.length; // End of file
+ }
+
+ const beforeSectionMarker = currentContent
+ .substring(0, startOfSectionContent)
+ .trimEnd();
+ let sectionContent = currentContent
+ .substring(startOfSectionContent, endOfSectionIndex)
+ .trimEnd();
+ const afterSectionMarker = currentContent.substring(endOfSectionIndex);
+
+ sectionContent += `\n${newMemoryItem}`;
+ return (
+ `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
+ '\n'
+ );
+ }
+ }
+
+ async shouldConfirmExecute(
+ params: SaveMemoryParams,
+ _abortSignal: AbortSignal,
+ ): Promise<ToolEditConfirmationDetails | false> {
+ const memoryFilePath = getGlobalMemoryFilePath();
+ const allowlistKey = memoryFilePath;
+
+ if (MemoryTool.allowlist.has(allowlistKey)) {
+ return false;
+ }
+
+ // Read current content of the memory file
+ const currentContent = await this.readMemoryFileContent();
+
+ // Calculate the new content that will be written to the memory file
+ const newContent = this.computeNewContent(currentContent, params.fact);
+
+ const fileName = path.basename(memoryFilePath);
+ const fileDiff = Diff.createPatch(
+ fileName,
+ currentContent,
+ newContent,
+ 'Current',
+ 'Proposed',
+ DEFAULT_DIFF_OPTIONS,
+ );
+
+ const confirmationDetails: ToolEditConfirmationDetails = {
+ type: 'edit',
+ title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`,
+ fileName: memoryFilePath,
+ fileDiff,
+ originalContent: currentContent,
+ newContent,
+ onConfirm: async (outcome: ToolConfirmationOutcome) => {
+ if (outcome === ToolConfirmationOutcome.ProceedAlways) {
+ MemoryTool.allowlist.add(allowlistKey);
+ }
+ },
+ };
+ return confirmationDetails;
+ }
+
static async performAddMemoryEntry(
text: string,
memoryFilePath: string,
@@ -184,7 +306,7 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
params: SaveMemoryParams,
_signal: AbortSignal,
): Promise<ToolResult> {
- const { fact } = params;
+ const { fact, modified_by_user, modified_content } = params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
@@ -195,17 +317,44 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
}
try {
- // Use the static method with actual fs promises
- await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), {
- readFile: fs.readFile,
- writeFile: fs.writeFile,
- mkdir: fs.mkdir,
- });
- const successMessage = `Okay, I've remembered that: "${fact}"`;
- return {
- llmContent: JSON.stringify({ success: true, message: successMessage }),
- returnDisplay: successMessage,
- };
+ if (modified_by_user && modified_content !== undefined) {
+ // User modified the content in external editor, write it directly
+ await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
+ recursive: true,
+ });
+ await fs.writeFile(
+ getGlobalMemoryFilePath(),
+ modified_content,
+ 'utf-8',
+ );
+ const successMessage = `Okay, I've updated the memory file with your modifications.`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ } else {
+ // Use the normal memory entry logic
+ await MemoryTool.performAddMemoryEntry(
+ fact,
+ getGlobalMemoryFilePath(),
+ {
+ readFile: fs.readFile,
+ writeFile: fs.writeFile,
+ mkdir: fs.mkdir,
+ },
+ );
+ const successMessage = `Okay, I've remembered that: "${fact}"`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ }
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
@@ -221,4 +370,25 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
};
}
}
+
+ getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
+ return {
+ getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(),
+ getCurrentContent: async (_params: SaveMemoryParams): Promise<string> =>
+ this.readMemoryFileContent(),
+ getProposedContent: async (params: SaveMemoryParams): Promise<string> => {
+ const currentContent = await this.readMemoryFileContent();
+ return this.computeNewContent(currentContent, params.fact);
+ },
+ createUpdatedParams: (
+ _oldContent: string,
+ modifiedProposedContent: string,
+ originalParams: SaveMemoryParams,
+ ): SaveMemoryParams => ({
+ ...originalParams,
+ modified_by_user: true,
+ modified_content: modifiedProposedContent,
+ }),
+ };
+ }
}