summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/tools/memoryTool.test.ts166
-rw-r--r--packages/core/src/tools/memoryTool.ts198
2 files changed, 343 insertions, 21 deletions
diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts
index aff0cc2e..5a9b5f26 100644
--- a/packages/core/src/tools/memoryTool.test.ts
+++ b/packages/core/src/tools/memoryTool.test.ts
@@ -15,6 +15,7 @@ import {
import * as fs from 'fs/promises';
import * as path from 'path';
import * as os from 'os';
+import { ToolConfirmationOutcome } from './tools.js';
// Mock dependencies
vi.mock('fs/promises');
@@ -46,7 +47,7 @@ describe('MemoryTool', () => {
};
beforeEach(() => {
- vi.mocked(os.homedir).mockReturnValue('/mock/home');
+ vi.mocked(os.homedir).mockReturnValue(path.join('/mock', 'home'));
mockFsAdapter.readFile.mockReset();
mockFsAdapter.writeFile.mockReset().mockResolvedValue(undefined);
mockFsAdapter.mkdir
@@ -85,11 +86,15 @@ describe('MemoryTool', () => {
});
describe('performAddMemoryEntry (static method)', () => {
- const testFilePath = path.join(
- '/mock/home',
- '.gemini',
- DEFAULT_CONTEXT_FILENAME, // Use the default for basic tests
- );
+ let testFilePath: string;
+
+ beforeEach(() => {
+ testFilePath = path.join(
+ os.homedir(),
+ '.gemini',
+ DEFAULT_CONTEXT_FILENAME,
+ );
+ });
it('should create section and save a fact if file does not exist', async () => {
mockFsAdapter.readFile.mockRejectedValue({ code: 'ENOENT' }); // Simulate file not found
@@ -206,7 +211,7 @@ describe('MemoryTool', () => {
const result = await memoryTool.execute(params, mockAbortSignal);
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join(
- '/mock/home',
+ os.homedir(),
'.gemini',
getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test
);
@@ -262,4 +267,151 @@ describe('MemoryTool', () => {
);
});
});
+
+ describe('shouldConfirmExecute', () => {
+ let memoryTool: MemoryTool;
+
+ beforeEach(() => {
+ memoryTool = new MemoryTool();
+ // Clear the allowlist before each test
+ (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.clear();
+ // Mock fs.readFile to return empty string (file doesn't exist)
+ vi.mocked(fs.readFile).mockResolvedValue('');
+ });
+
+ it('should return confirmation details when memory file is not allowlisted', async () => {
+ const params = { fact: 'Test fact' };
+ const result = await memoryTool.shouldConfirmExecute(
+ params,
+ mockAbortSignal,
+ );
+
+ expect(result).toBeDefined();
+ expect(result).not.toBe(false);
+
+ if (result && result.type === 'edit') {
+ const expectedPath = path.join('~', '.gemini', 'GEMINI.md');
+ expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`);
+ expect(result.fileName).toContain(path.join('mock', 'home', '.gemini'));
+ expect(result.fileName).toContain('GEMINI.md');
+ expect(result.fileDiff).toContain('Index: GEMINI.md');
+ expect(result.fileDiff).toContain('+## Gemini Added Memories');
+ expect(result.fileDiff).toContain('+- Test fact');
+ expect(result.originalContent).toBe('');
+ expect(result.newContent).toContain('## Gemini Added Memories');
+ expect(result.newContent).toContain('- Test fact');
+ }
+ });
+
+ it('should return false when memory file is already allowlisted', async () => {
+ const params = { fact: 'Test fact' };
+ const memoryFilePath = path.join(
+ os.homedir(),
+ '.gemini',
+ getCurrentGeminiMdFilename(),
+ );
+
+ // Add the memory file to the allowlist
+ (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.add(
+ memoryFilePath,
+ );
+
+ const result = await memoryTool.shouldConfirmExecute(
+ params,
+ mockAbortSignal,
+ );
+
+ expect(result).toBe(false);
+ });
+
+ it('should add memory file to allowlist when ProceedAlways is confirmed', async () => {
+ const params = { fact: 'Test fact' };
+ const memoryFilePath = path.join(
+ os.homedir(),
+ '.gemini',
+ getCurrentGeminiMdFilename(),
+ );
+
+ const result = await memoryTool.shouldConfirmExecute(
+ params,
+ mockAbortSignal,
+ );
+
+ expect(result).toBeDefined();
+ expect(result).not.toBe(false);
+
+ if (result && result.type === 'edit') {
+ // Simulate the onConfirm callback
+ await result.onConfirm(ToolConfirmationOutcome.ProceedAlways);
+
+ // Check that the memory file was added to the allowlist
+ expect(
+ (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
+ memoryFilePath,
+ ),
+ ).toBe(true);
+ }
+ });
+
+ it('should not add memory file to allowlist when other outcomes are confirmed', async () => {
+ const params = { fact: 'Test fact' };
+ const memoryFilePath = path.join(
+ os.homedir(),
+ '.gemini',
+ getCurrentGeminiMdFilename(),
+ );
+
+ const result = await memoryTool.shouldConfirmExecute(
+ params,
+ mockAbortSignal,
+ );
+
+ expect(result).toBeDefined();
+ expect(result).not.toBe(false);
+
+ if (result && result.type === 'edit') {
+ // Simulate the onConfirm callback with different outcomes
+ await result.onConfirm(ToolConfirmationOutcome.ProceedOnce);
+ expect(
+ (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
+ memoryFilePath,
+ ),
+ ).toBe(false);
+
+ await result.onConfirm(ToolConfirmationOutcome.Cancel);
+ expect(
+ (MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
+ memoryFilePath,
+ ),
+ ).toBe(false);
+ }
+ });
+
+ it('should handle existing memory file with content', async () => {
+ const params = { fact: 'New fact' };
+ const existingContent =
+ 'Some existing content.\n\n## Gemini Added Memories\n- Old fact\n';
+
+ // Mock fs.readFile to return existing content
+ vi.mocked(fs.readFile).mockResolvedValue(existingContent);
+
+ const result = await memoryTool.shouldConfirmExecute(
+ params,
+ mockAbortSignal,
+ );
+
+ expect(result).toBeDefined();
+ expect(result).not.toBe(false);
+
+ if (result && result.type === 'edit') {
+ const expectedPath = path.join('~', '.gemini', 'GEMINI.md');
+ expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`);
+ expect(result.fileDiff).toContain('Index: GEMINI.md');
+ expect(result.fileDiff).toContain('+- New fact');
+ expect(result.originalContent).toBe(existingContent);
+ expect(result.newContent).toContain('- Old fact');
+ expect(result.newContent).toContain('- New fact');
+ }
+ });
+ });
});
diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts
index f0f1e16b..96509f79 100644
--- a/packages/core/src/tools/memoryTool.ts
+++ b/packages/core/src/tools/memoryTool.ts
@@ -4,11 +4,21 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { BaseTool, Icon, ToolResult } from './tools.js';
+import {
+ BaseTool,
+ ToolResult,
+ ToolEditConfirmationDetails,
+ ToolConfirmationOutcome,
+ Icon,
+} from './tools.js';
import { FunctionDeclaration, Type } from '@google/genai';
import * as fs from 'fs/promises';
import * as path from 'path';
import { homedir } from 'os';
+import * as Diff from 'diff';
+import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
+import { tildeifyPath } from '../utils/paths.js';
+import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@@ -80,6 +90,8 @@ export function getAllGeminiMdFilenames(): string[] {
interface SaveMemoryParams {
fact: string;
+ modified_by_user?: boolean;
+ modified_content?: string;
}
function getGlobalMemoryFilePath(): string {
@@ -98,7 +110,12 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n';
}
-export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
+export class MemoryTool
+ extends BaseTool<SaveMemoryParams, ToolResult>
+ implements ModifiableTool<SaveMemoryParams>
+{
+ private static readonly allowlist: Set<string> = new Set();
+
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
@@ -110,6 +127,111 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
);
}
+ getDescription(_params: SaveMemoryParams): string {
+ const memoryFilePath = getGlobalMemoryFilePath();
+ return `in ${tildeifyPath(memoryFilePath)}`;
+ }
+
+ /**
+ * Reads the current content of the memory file
+ */
+ private async readMemoryFileContent(): Promise<string> {
+ try {
+ return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
+ } catch (err) {
+ const error = err as Error & { code?: string };
+ if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
+ return '';
+ }
+ }
+
+ /**
+ * Computes the new content that would result from adding a memory entry
+ */
+ private computeNewContent(currentContent: string, fact: string): string {
+ let processedText = fact.trim();
+ processedText = processedText.replace(/^(-+\s*)+/, '').trim();
+ const newMemoryItem = `- ${processedText}`;
+
+ const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
+
+ if (headerIndex === -1) {
+ // Header not found, append header and then the entry
+ const separator = ensureNewlineSeparation(currentContent);
+ return (
+ currentContent +
+ `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
+ );
+ } else {
+ // Header found, find where to insert the new memory entry
+ const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
+ let endOfSectionIndex = currentContent.indexOf(
+ '\n## ',
+ startOfSectionContent,
+ );
+ if (endOfSectionIndex === -1) {
+ endOfSectionIndex = currentContent.length; // End of file
+ }
+
+ const beforeSectionMarker = currentContent
+ .substring(0, startOfSectionContent)
+ .trimEnd();
+ let sectionContent = currentContent
+ .substring(startOfSectionContent, endOfSectionIndex)
+ .trimEnd();
+ const afterSectionMarker = currentContent.substring(endOfSectionIndex);
+
+ sectionContent += `\n${newMemoryItem}`;
+ return (
+ `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
+ '\n'
+ );
+ }
+ }
+
+ async shouldConfirmExecute(
+ params: SaveMemoryParams,
+ _abortSignal: AbortSignal,
+ ): Promise<ToolEditConfirmationDetails | false> {
+ const memoryFilePath = getGlobalMemoryFilePath();
+ const allowlistKey = memoryFilePath;
+
+ if (MemoryTool.allowlist.has(allowlistKey)) {
+ return false;
+ }
+
+ // Read current content of the memory file
+ const currentContent = await this.readMemoryFileContent();
+
+ // Calculate the new content that will be written to the memory file
+ const newContent = this.computeNewContent(currentContent, params.fact);
+
+ const fileName = path.basename(memoryFilePath);
+ const fileDiff = Diff.createPatch(
+ fileName,
+ currentContent,
+ newContent,
+ 'Current',
+ 'Proposed',
+ DEFAULT_DIFF_OPTIONS,
+ );
+
+ const confirmationDetails: ToolEditConfirmationDetails = {
+ type: 'edit',
+ title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`,
+ fileName: memoryFilePath,
+ fileDiff,
+ originalContent: currentContent,
+ newContent,
+ onConfirm: async (outcome: ToolConfirmationOutcome) => {
+ if (outcome === ToolConfirmationOutcome.ProceedAlways) {
+ MemoryTool.allowlist.add(allowlistKey);
+ }
+ },
+ };
+ return confirmationDetails;
+ }
+
static async performAddMemoryEntry(
text: string,
memoryFilePath: string,
@@ -184,7 +306,7 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
params: SaveMemoryParams,
_signal: AbortSignal,
): Promise<ToolResult> {
- const { fact } = params;
+ const { fact, modified_by_user, modified_content } = params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
@@ -195,17 +317,44 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
}
try {
- // Use the static method with actual fs promises
- await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), {
- readFile: fs.readFile,
- writeFile: fs.writeFile,
- mkdir: fs.mkdir,
- });
- const successMessage = `Okay, I've remembered that: "${fact}"`;
- return {
- llmContent: JSON.stringify({ success: true, message: successMessage }),
- returnDisplay: successMessage,
- };
+ if (modified_by_user && modified_content !== undefined) {
+ // User modified the content in external editor, write it directly
+ await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
+ recursive: true,
+ });
+ await fs.writeFile(
+ getGlobalMemoryFilePath(),
+ modified_content,
+ 'utf-8',
+ );
+ const successMessage = `Okay, I've updated the memory file with your modifications.`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ } else {
+ // Use the normal memory entry logic
+ await MemoryTool.performAddMemoryEntry(
+ fact,
+ getGlobalMemoryFilePath(),
+ {
+ readFile: fs.readFile,
+ writeFile: fs.writeFile,
+ mkdir: fs.mkdir,
+ },
+ );
+ const successMessage = `Okay, I've remembered that: "${fact}"`;
+ return {
+ llmContent: JSON.stringify({
+ success: true,
+ message: successMessage,
+ }),
+ returnDisplay: successMessage,
+ };
+ }
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
@@ -221,4 +370,25 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
};
}
}
+
+ getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
+ return {
+ getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(),
+ getCurrentContent: async (_params: SaveMemoryParams): Promise<string> =>
+ this.readMemoryFileContent(),
+ getProposedContent: async (params: SaveMemoryParams): Promise<string> => {
+ const currentContent = await this.readMemoryFileContent();
+ return this.computeNewContent(currentContent, params.fact);
+ },
+ createUpdatedParams: (
+ _oldContent: string,
+ modifiedProposedContent: string,
+ originalParams: SaveMemoryParams,
+ ): SaveMemoryParams => ({
+ ...originalParams,
+ modified_by_user: true,
+ modified_content: modifiedProposedContent,
+ }),
+ };
+ }
}