summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbhi <[email protected]>2025-07-17 19:23:17 -0400
committerGitHub <[email protected]>2025-07-17 23:23:17 +0000
commit5df6c9fb660f932d54d6b5d1080cb86c95a824cf (patch)
tree82ef85725cba809278e2f2eef7395ea7458e84ff
parentf0dc9690b7903532099a5a2c4d98e02b3d2382bf (diff)
migrate restore command (#4388)
-rw-r--r--packages/cli/src/services/CommandService.test.ts20
-rw-r--r--packages/cli/src/services/CommandService.ts2
-rw-r--r--packages/cli/src/test-utils/mockCommandContext.ts1
-rw-r--r--packages/cli/src/ui/commands/restoreCommand.test.ts237
-rw-r--r--packages/cli/src/ui/commands/restoreCommand.ts155
-rw-r--r--packages/cli/src/ui/commands/types.ts6
-rw-r--r--packages/cli/src/ui/hooks/slashCommandProcessor.ts126
7 files changed, 424 insertions, 123 deletions
diff --git a/packages/cli/src/services/CommandService.test.ts b/packages/cli/src/services/CommandService.test.ts
index 084f603b..6ae52b52 100644
--- a/packages/cli/src/services/CommandService.test.ts
+++ b/packages/cli/src/services/CommandService.test.ts
@@ -26,6 +26,7 @@ import { mcpCommand } from '../ui/commands/mcpCommand.js';
import { editorCommand } from '../ui/commands/editorCommand.js';
import { bugCommand } from '../ui/commands/bugCommand.js';
import { quitCommand } from '../ui/commands/quitCommand.js';
+import { restoreCommand } from '../ui/commands/restoreCommand.js';
// Mock the command modules to isolate the service from the command implementations.
vi.mock('../ui/commands/memoryCommand.js', () => ({
@@ -79,6 +80,9 @@ vi.mock('../ui/commands/bugCommand.js', () => ({
vi.mock('../ui/commands/quitCommand.js', () => ({
quitCommand: { name: 'quit', description: 'Mock Quit' },
}));
+vi.mock('../ui/commands/restoreCommand.js', () => ({
+ restoreCommand: vi.fn(),
+}));
describe('CommandService', () => {
const subCommandLen = 17;
@@ -87,8 +91,10 @@ describe('CommandService', () => {
beforeEach(() => {
mockConfig = {
getIdeMode: vi.fn(),
+ getCheckpointingEnabled: vi.fn(),
} as unknown as Mocked<Config>;
vi.mocked(ideCommand).mockReturnValue(null);
+ vi.mocked(restoreCommand).mockReturnValue(null);
});
describe('when using default production loader', () => {
@@ -151,6 +157,20 @@ describe('CommandService', () => {
expect(commandNames).toContain('quit');
});
+ it('should include restore command when checkpointing is on', async () => {
+ mockConfig.getCheckpointingEnabled.mockReturnValue(true);
+ vi.mocked(restoreCommand).mockReturnValue({
+ name: 'restore',
+ description: 'Mock Restore',
+ });
+ await commandService.loadCommands();
+ const tree = commandService.getCommands();
+
+ expect(tree.length).toBe(subCommandLen + 1);
+ const commandNames = tree.map((cmd) => cmd.name);
+ expect(commandNames).toContain('restore');
+ });
+
it('should overwrite any existing commands when called again', async () => {
// Load once
await commandService.loadCommands();
diff --git a/packages/cli/src/services/CommandService.ts b/packages/cli/src/services/CommandService.ts
index 773f5b31..611b0a7b 100644
--- a/packages/cli/src/services/CommandService.ts
+++ b/packages/cli/src/services/CommandService.ts
@@ -24,6 +24,7 @@ import { compressCommand } from '../ui/commands/compressCommand.js';
import { ideCommand } from '../ui/commands/ideCommand.js';
import { bugCommand } from '../ui/commands/bugCommand.js';
import { quitCommand } from '../ui/commands/quitCommand.js';
+import { restoreCommand } from '../ui/commands/restoreCommand.js';
const loadBuiltInCommands = async (
config: Config | null,
@@ -44,6 +45,7 @@ const loadBuiltInCommands = async (
memoryCommand,
privacyCommand,
quitCommand,
+ restoreCommand(config),
statsCommand,
themeCommand,
toolsCommand,
diff --git a/packages/cli/src/test-utils/mockCommandContext.ts b/packages/cli/src/test-utils/mockCommandContext.ts
index 899d5747..3fb33b3f 100644
--- a/packages/cli/src/test-utils/mockCommandContext.ts
+++ b/packages/cli/src/test-utils/mockCommandContext.ts
@@ -46,6 +46,7 @@ export const createMockCommandContext = (
setDebugMessage: vi.fn(),
pendingItem: null,
setPendingItem: vi.fn(),
+ loadHistory: vi.fn(),
},
session: {
stats: {
diff --git a/packages/cli/src/ui/commands/restoreCommand.test.ts b/packages/cli/src/ui/commands/restoreCommand.test.ts
new file mode 100644
index 00000000..53cd7d18
--- /dev/null
+++ b/packages/cli/src/ui/commands/restoreCommand.test.ts
@@ -0,0 +1,237 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+ vi,
+ describe,
+ it,
+ expect,
+ beforeEach,
+ afterEach,
+ Mocked,
+ Mock,
+} from 'vitest';
+import * as fs from 'fs/promises';
+import { restoreCommand } from './restoreCommand.js';
+import { type CommandContext } from './types.js';
+import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
+import { Config, GitService } from '@google/gemini-cli-core';
+
+vi.mock('fs/promises', () => ({
+ readdir: vi.fn(),
+ readFile: vi.fn(),
+ mkdir: vi.fn(),
+}));
+
+describe('restoreCommand', () => {
+ let mockContext: CommandContext;
+ let mockConfig: Config;
+ let mockGitService: GitService;
+ const mockFsPromises = fs as Mocked<typeof fs>;
+ let mockSetHistory: ReturnType<typeof vi.fn>;
+
+ beforeEach(() => {
+ mockSetHistory = vi.fn().mockResolvedValue(undefined);
+ mockGitService = {
+ restoreProjectFromSnapshot: vi.fn().mockResolvedValue(undefined),
+ } as unknown as GitService;
+
+ mockConfig = {
+ getCheckpointingEnabled: vi.fn().mockReturnValue(true),
+ getProjectTempDir: vi.fn().mockReturnValue('/tmp/gemini'),
+ getGeminiClient: vi.fn().mockReturnValue({
+ setHistory: mockSetHistory,
+ }),
+ } as unknown as Config;
+
+ mockContext = createMockCommandContext({
+ services: {
+ config: mockConfig,
+ git: mockGitService,
+ },
+ });
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ it('should return null if checkpointing is not enabled', () => {
+ (mockConfig.getCheckpointingEnabled as Mock).mockReturnValue(false);
+ const command = restoreCommand(mockConfig);
+ expect(command).toBeNull();
+ });
+
+ it('should return the command if checkpointing is enabled', () => {
+ const command = restoreCommand(mockConfig);
+ expect(command).not.toBeNull();
+ expect(command?.name).toBe('restore');
+ expect(command?.description).toBeDefined();
+ expect(command?.action).toBeDefined();
+ expect(command?.completion).toBeDefined();
+ });
+
+ describe('action', () => {
+ it('should return an error if temp dir is not found', async () => {
+ (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, '');
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Could not determine the .gemini directory path.',
+ });
+ });
+
+ it('should inform when no checkpoints are found if no args are passed', async () => {
+ mockFsPromises.readdir.mockResolvedValue([]);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, '');
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'info',
+ content: 'No restorable tool calls found.',
+ });
+ expect(mockFsPromises.mkdir).toHaveBeenCalledWith(
+ '/tmp/gemini/checkpoints',
+ {
+ recursive: true,
+ },
+ );
+ });
+
+ it('should list available checkpoints if no args are passed', async () => {
+ mockFsPromises.readdir.mockResolvedValue([
+ 'test1.json',
+ 'test2.json',
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ ] as any);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, '');
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'info',
+ content: 'Available tool calls to restore:\n\ntest1\ntest2',
+ });
+ });
+
+ it('should return an error if the specified file is not found', async () => {
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, 'test2');
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'File not found: test2.json',
+ });
+ });
+
+ it('should handle file read errors gracefully', async () => {
+ const readError = new Error('Read failed');
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any);
+ mockFsPromises.readFile.mockRejectedValue(readError);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, 'test1');
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: `Could not read restorable tool calls. This is the error: ${readError}`,
+ });
+ });
+
+ it('should restore a tool call and project state', async () => {
+ const toolCallData = {
+ history: [{ type: 'user', text: 'do a thing' }],
+ clientHistory: [{ role: 'user', parts: [{ text: 'do a thing' }] }],
+ commitHash: 'abcdef123',
+ toolCall: { name: 'run_shell_command', args: 'ls' },
+ };
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any);
+ mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData));
+
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, 'my-checkpoint');
+
+ // Check history restoration
+ expect(mockContext.ui.loadHistory).toHaveBeenCalledWith(
+ toolCallData.history,
+ );
+ expect(mockSetHistory).toHaveBeenCalledWith(toolCallData.clientHistory);
+
+ // Check git restoration
+ expect(mockGitService.restoreProjectFromSnapshot).toHaveBeenCalledWith(
+ toolCallData.commitHash,
+ );
+ expect(mockContext.ui.addItem).toHaveBeenCalledWith(
+ {
+ type: 'info',
+ text: 'Restored project to the state before the tool call.',
+ },
+ expect.any(Number),
+ );
+
+ // Check returned action
+ expect(result).toEqual({
+ type: 'tool',
+ toolName: 'run_shell_command',
+ toolArgs: 'ls',
+ });
+ });
+
+ it('should restore even if only toolCall is present', async () => {
+ const toolCallData = {
+ toolCall: { name: 'run_shell_command', args: 'ls' },
+ };
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any);
+ mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData));
+
+ const command = restoreCommand(mockConfig);
+ const result = await command?.action?.(mockContext, 'my-checkpoint');
+
+ expect(mockContext.ui.loadHistory).not.toHaveBeenCalled();
+ expect(mockSetHistory).not.toHaveBeenCalled();
+ expect(mockGitService.restoreProjectFromSnapshot).not.toHaveBeenCalled();
+
+ expect(result).toEqual({
+ type: 'tool',
+ toolName: 'run_shell_command',
+ toolArgs: 'ls',
+ });
+ });
+ });
+
+ describe('completion', () => {
+ it('should return an empty array if temp dir is not found', async () => {
+ (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.completion?.(mockContext, '');
+ expect(result).toEqual([]);
+ });
+
+ it('should return an empty array on readdir error', async () => {
+ mockFsPromises.readdir.mockRejectedValue(new Error('ENOENT'));
+ const command = restoreCommand(mockConfig);
+ const result = await command?.completion?.(mockContext, '');
+ expect(result).toEqual([]);
+ });
+
+ it('should return a list of checkpoint names', async () => {
+ mockFsPromises.readdir.mockResolvedValue([
+ 'test1.json',
+ 'test2.json',
+ 'not-a-checkpoint.txt',
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ ] as any);
+ const command = restoreCommand(mockConfig);
+ const result = await command?.completion?.(mockContext, '');
+ expect(result).toEqual(['test1', 'test2']);
+ });
+ });
+});
diff --git a/packages/cli/src/ui/commands/restoreCommand.ts b/packages/cli/src/ui/commands/restoreCommand.ts
new file mode 100644
index 00000000..3d744189
--- /dev/null
+++ b/packages/cli/src/ui/commands/restoreCommand.ts
@@ -0,0 +1,155 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import * as fs from 'fs/promises';
+import path from 'path';
+import {
+ type CommandContext,
+ type SlashCommand,
+ type SlashCommandActionReturn,
+} from './types.js';
+import { Config } from '@google/gemini-cli-core';
+
+async function restoreAction(
+ context: CommandContext,
+ args: string,
+): Promise<void | SlashCommandActionReturn> {
+ const { services, ui } = context;
+ const { config, git: gitService } = services;
+ const { addItem, loadHistory } = ui;
+
+ const checkpointDir = config?.getProjectTempDir()
+ ? path.join(config.getProjectTempDir(), 'checkpoints')
+ : undefined;
+
+ if (!checkpointDir) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Could not determine the .gemini directory path.',
+ };
+ }
+
+ try {
+ // Ensure the directory exists before trying to read it.
+ await fs.mkdir(checkpointDir, { recursive: true });
+ const files = await fs.readdir(checkpointDir);
+ const jsonFiles = files.filter((file) => file.endsWith('.json'));
+
+ if (!args) {
+ if (jsonFiles.length === 0) {
+ return {
+ type: 'message',
+ messageType: 'info',
+ content: 'No restorable tool calls found.',
+ };
+ }
+ const truncatedFiles = jsonFiles.map((file) => {
+ const components = file.split('.');
+ if (components.length <= 1) {
+ return file;
+ }
+ components.pop();
+ return components.join('.');
+ });
+ const fileList = truncatedFiles.join('\n');
+ return {
+ type: 'message',
+ messageType: 'info',
+ content: `Available tool calls to restore:\n\n${fileList}`,
+ };
+ }
+
+ const selectedFile = args.endsWith('.json') ? args : `${args}.json`;
+
+ if (!jsonFiles.includes(selectedFile)) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `File not found: ${selectedFile}`,
+ };
+ }
+
+ const filePath = path.join(checkpointDir, selectedFile);
+ const data = await fs.readFile(filePath, 'utf-8');
+ const toolCallData = JSON.parse(data);
+
+ if (toolCallData.history) {
+ if (!loadHistory) {
+ // This should not happen
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'loadHistory function is not available.',
+ };
+ }
+ loadHistory(toolCallData.history);
+ }
+
+ if (toolCallData.clientHistory) {
+ await config?.getGeminiClient()?.setHistory(toolCallData.clientHistory);
+ }
+
+ if (toolCallData.commitHash) {
+ await gitService?.restoreProjectFromSnapshot(toolCallData.commitHash);
+ addItem(
+ {
+ type: 'info',
+ text: 'Restored project to the state before the tool call.',
+ },
+ Date.now(),
+ );
+ }
+
+ return {
+ type: 'tool',
+ toolName: toolCallData.toolCall.name,
+ toolArgs: toolCallData.toolCall.args,
+ };
+ } catch (error) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `Could not read restorable tool calls. This is the error: ${error}`,
+ };
+ }
+}
+
+async function completion(
+ context: CommandContext,
+ _partialArg: string,
+): Promise<string[]> {
+ const { services } = context;
+ const { config } = services;
+ const checkpointDir = config?.getProjectTempDir()
+ ? path.join(config.getProjectTempDir(), 'checkpoints')
+ : undefined;
+ if (!checkpointDir) {
+ return [];
+ }
+ try {
+ const files = await fs.readdir(checkpointDir);
+ return files
+ .filter((file) => file.endsWith('.json'))
+ .map((file) => file.replace('.json', ''));
+ } catch (_err) {
+ return [];
+ }
+}
+
+export const restoreCommand = (config: Config | null): SlashCommand | null => {
+ if (!config?.getCheckpointingEnabled()) {
+ return null;
+ }
+
+ return {
+ name: 'restore',
+ description:
+ 'Restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
+ action: restoreAction,
+ completion,
+ };
+};
diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts
index d3d5ee8a..51b66fb4 100644
--- a/packages/cli/src/ui/commands/types.ts
+++ b/packages/cli/src/ui/commands/types.ts
@@ -41,6 +41,12 @@ export interface CommandContext {
* @param item The history item to display as pending, or `null` to clear.
*/
setPendingItem: (item: HistoryItemWithoutId | null) => void;
+ /**
+ * Loads a new set of history items, replacing the current history.
+ *
+ * @param history The array of history items to load.
+ */
+ loadHistory: UseHistoryManagerReturn['loadHistory'];
};
// Session-specific data
session: {
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
index 125d051e..295d1c50 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
@@ -18,8 +18,6 @@ import {
HistoryItem,
SlashCommandProcessorResult,
} from '../types.js';
-import { promises as fs } from 'fs';
-import path from 'path';
import { LoadedSettings } from '../../config/settings.js';
import {
type CommandContext,
@@ -155,6 +153,7 @@ export const useSlashCommandProcessor = (
console.clear();
refreshStatic();
},
+ loadHistory,
setDebugMessage: onDebugMessage,
pendingItem: pendingCompressionItemRef.current,
setPendingItem: setPendingCompressionItem,
@@ -168,6 +167,7 @@ export const useSlashCommandProcessor = (
settings,
gitService,
logger,
+ loadHistory,
addItem,
clearItems,
refreshStatic,
@@ -203,128 +203,8 @@ export const useSlashCommandProcessor = (
},
];
- if (config?.getCheckpointingEnabled()) {
- commands.push({
- name: 'restore',
- description:
- 'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
- completion: async () => {
- const checkpointDir = config?.getProjectTempDir()
- ? path.join(config.getProjectTempDir(), 'checkpoints')
- : undefined;
- if (!checkpointDir) {
- return [];
- }
- try {
- const files = await fs.readdir(checkpointDir);
- return files
- .filter((file) => file.endsWith('.json'))
- .map((file) => file.replace('.json', ''));
- } catch (_err) {
- return [];
- }
- },
- action: async (_mainCommand, subCommand, _args) => {
- const checkpointDir = config?.getProjectTempDir()
- ? path.join(config.getProjectTempDir(), 'checkpoints')
- : undefined;
-
- if (!checkpointDir) {
- addMessage({
- type: MessageType.ERROR,
- content: 'Could not determine the .gemini directory path.',
- timestamp: new Date(),
- });
- return;
- }
-
- try {
- // Ensure the directory exists before trying to read it.
- await fs.mkdir(checkpointDir, { recursive: true });
- const files = await fs.readdir(checkpointDir);
- const jsonFiles = files.filter((file) => file.endsWith('.json'));
-
- if (!subCommand) {
- if (jsonFiles.length === 0) {
- addMessage({
- type: MessageType.INFO,
- content: 'No restorable tool calls found.',
- timestamp: new Date(),
- });
- return;
- }
- const truncatedFiles = jsonFiles.map((file) => {
- const components = file.split('.');
- if (components.length <= 1) {
- return file;
- }
- components.pop();
- return components.join('.');
- });
- const fileList = truncatedFiles.join('\n');
- addMessage({
- type: MessageType.INFO,
- content: `Available tool calls to restore:\n\n${fileList}`,
- timestamp: new Date(),
- });
- return;
- }
-
- const selectedFile = subCommand.endsWith('.json')
- ? subCommand
- : `${subCommand}.json`;
-
- if (!jsonFiles.includes(selectedFile)) {
- addMessage({
- type: MessageType.ERROR,
- content: `File not found: ${selectedFile}`,
- timestamp: new Date(),
- });
- return;
- }
-
- const filePath = path.join(checkpointDir, selectedFile);
- const data = await fs.readFile(filePath, 'utf-8');
- const toolCallData = JSON.parse(data);
-
- if (toolCallData.history) {
- loadHistory(toolCallData.history);
- }
-
- if (toolCallData.clientHistory) {
- await config
- ?.getGeminiClient()
- ?.setHistory(toolCallData.clientHistory);
- }
-
- if (toolCallData.commitHash) {
- await gitService?.restoreProjectFromSnapshot(
- toolCallData.commitHash,
- );
- addMessage({
- type: MessageType.INFO,
- content: `Restored project to the state before the tool call.`,
- timestamp: new Date(),
- });
- }
-
- return {
- type: 'tool',
- toolName: toolCallData.toolCall.name,
- toolArgs: toolCallData.toolCall.args,
- };
- } catch (error) {
- addMessage({
- type: MessageType.ERROR,
- content: `Could not read restorable tool calls. This is the error: ${error}`,
- timestamp: new Date(),
- });
- }
- },
- });
- }
return commands;
- }, [addMessage, toggleCorgiMode, config, gitService, loadHistory]);
+ }, [toggleCorgiMode]);
const handleSlashCommand = useCallback(
async (