diff options
Diffstat (limited to 'packages/cli/src/ui/commands')
| -rw-r--r-- | packages/cli/src/ui/commands/restoreCommand.test.ts | 237 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/restoreCommand.ts | 155 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/types.ts | 6 |
3 files changed, 398 insertions, 0 deletions
diff --git a/packages/cli/src/ui/commands/restoreCommand.test.ts b/packages/cli/src/ui/commands/restoreCommand.test.ts new file mode 100644 index 00000000..53cd7d18 --- /dev/null +++ b/packages/cli/src/ui/commands/restoreCommand.test.ts @@ -0,0 +1,237 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + Mocked, + Mock, +} from 'vitest'; +import * as fs from 'fs/promises'; +import { restoreCommand } from './restoreCommand.js'; +import { type CommandContext } from './types.js'; +import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import { Config, GitService } from '@google/gemini-cli-core'; + +vi.mock('fs/promises', () => ({ + readdir: vi.fn(), + readFile: vi.fn(), + mkdir: vi.fn(), +})); + +describe('restoreCommand', () => { + let mockContext: CommandContext; + let mockConfig: Config; + let mockGitService: GitService; + const mockFsPromises = fs as Mocked<typeof fs>; + let mockSetHistory: ReturnType<typeof vi.fn>; + + beforeEach(() => { + mockSetHistory = vi.fn().mockResolvedValue(undefined); + mockGitService = { + restoreProjectFromSnapshot: vi.fn().mockResolvedValue(undefined), + } as unknown as GitService; + + mockConfig = { + getCheckpointingEnabled: vi.fn().mockReturnValue(true), + getProjectTempDir: vi.fn().mockReturnValue('/tmp/gemini'), + getGeminiClient: vi.fn().mockReturnValue({ + setHistory: mockSetHistory, + }), + } as unknown as Config; + + mockContext = createMockCommandContext({ + services: { + config: mockConfig, + git: mockGitService, + }, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should return null if checkpointing is not enabled', () => { + (mockConfig.getCheckpointingEnabled as Mock).mockReturnValue(false); + const command = restoreCommand(mockConfig); + expect(command).toBeNull(); + }); + + it('should return the command if checkpointing is enabled', () => { + const command = restoreCommand(mockConfig); + expect(command).not.toBeNull(); + expect(command?.name).toBe('restore'); + expect(command?.description).toBeDefined(); + expect(command?.action).toBeDefined(); + expect(command?.completion).toBeDefined(); + }); + + describe('action', () => { + it('should return an error if temp dir is not found', async () => { + (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Could not determine the .gemini directory path.', + }); + }); + + it('should inform when no checkpoints are found if no args are passed', async () => { + mockFsPromises.readdir.mockResolvedValue([]); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'info', + content: 'No restorable tool calls found.', + }); + expect(mockFsPromises.mkdir).toHaveBeenCalledWith( + '/tmp/gemini/checkpoints', + { + recursive: true, + }, + ); + }); + + it('should list available checkpoints if no args are passed', async () => { + mockFsPromises.readdir.mockResolvedValue([ + 'test1.json', + 'test2.json', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'info', + content: 'Available tool calls to restore:\n\ntest1\ntest2', + }); + }); + + it('should return an error if the specified file is not found', async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'test2'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'File not found: test2.json', + }); + }); + + it('should handle file read errors gracefully', async () => { + const readError = new Error('Read failed'); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any); + mockFsPromises.readFile.mockRejectedValue(readError); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'test1'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: `Could not read restorable tool calls. This is the error: ${readError}`, + }); + }); + + it('should restore a tool call and project state', async () => { + const toolCallData = { + history: [{ type: 'user', text: 'do a thing' }], + clientHistory: [{ role: 'user', parts: [{ text: 'do a thing' }] }], + commitHash: 'abcdef123', + toolCall: { name: 'run_shell_command', args: 'ls' }, + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any); + mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData)); + + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'my-checkpoint'); + + // Check history restoration + expect(mockContext.ui.loadHistory).toHaveBeenCalledWith( + toolCallData.history, + ); + expect(mockSetHistory).toHaveBeenCalledWith(toolCallData.clientHistory); + + // Check git restoration + expect(mockGitService.restoreProjectFromSnapshot).toHaveBeenCalledWith( + toolCallData.commitHash, + ); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'Restored project to the state before the tool call.', + }, + expect.any(Number), + ); + + // Check returned action + expect(result).toEqual({ + type: 'tool', + toolName: 'run_shell_command', + toolArgs: 'ls', + }); + }); + + it('should restore even if only toolCall is present', async () => { + const toolCallData = { + toolCall: { name: 'run_shell_command', args: 'ls' }, + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any); + mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData)); + + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'my-checkpoint'); + + expect(mockContext.ui.loadHistory).not.toHaveBeenCalled(); + expect(mockSetHistory).not.toHaveBeenCalled(); + expect(mockGitService.restoreProjectFromSnapshot).not.toHaveBeenCalled(); + + expect(result).toEqual({ + type: 'tool', + toolName: 'run_shell_command', + toolArgs: 'ls', + }); + }); + }); + + describe('completion', () => { + it('should return an empty array if temp dir is not found', async () => { + (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual([]); + }); + + it('should return an empty array on readdir error', async () => { + mockFsPromises.readdir.mockRejectedValue(new Error('ENOENT')); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual([]); + }); + + it('should return a list of checkpoint names', async () => { + mockFsPromises.readdir.mockResolvedValue([ + 'test1.json', + 'test2.json', + 'not-a-checkpoint.txt', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual(['test1', 'test2']); + }); + }); +}); diff --git a/packages/cli/src/ui/commands/restoreCommand.ts b/packages/cli/src/ui/commands/restoreCommand.ts new file mode 100644 index 00000000..3d744189 --- /dev/null +++ b/packages/cli/src/ui/commands/restoreCommand.ts @@ -0,0 +1,155 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'fs/promises'; +import path from 'path'; +import { + type CommandContext, + type SlashCommand, + type SlashCommandActionReturn, +} from './types.js'; +import { Config } from '@google/gemini-cli-core'; + +async function restoreAction( + context: CommandContext, + args: string, +): Promise<void | SlashCommandActionReturn> { + const { services, ui } = context; + const { config, git: gitService } = services; + const { addItem, loadHistory } = ui; + + const checkpointDir = config?.getProjectTempDir() + ? path.join(config.getProjectTempDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + return { + type: 'message', + messageType: 'error', + content: 'Could not determine the .gemini directory path.', + }; + } + + try { + // Ensure the directory exists before trying to read it. + await fs.mkdir(checkpointDir, { recursive: true }); + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + if (!args) { + if (jsonFiles.length === 0) { + return { + type: 'message', + messageType: 'info', + content: 'No restorable tool calls found.', + }; + } + const truncatedFiles = jsonFiles.map((file) => { + const components = file.split('.'); + if (components.length <= 1) { + return file; + } + components.pop(); + return components.join('.'); + }); + const fileList = truncatedFiles.join('\n'); + return { + type: 'message', + messageType: 'info', + content: `Available tool calls to restore:\n\n${fileList}`, + }; + } + + const selectedFile = args.endsWith('.json') ? args : `${args}.json`; + + if (!jsonFiles.includes(selectedFile)) { + return { + type: 'message', + messageType: 'error', + content: `File not found: ${selectedFile}`, + }; + } + + const filePath = path.join(checkpointDir, selectedFile); + const data = await fs.readFile(filePath, 'utf-8'); + const toolCallData = JSON.parse(data); + + if (toolCallData.history) { + if (!loadHistory) { + // This should not happen + return { + type: 'message', + messageType: 'error', + content: 'loadHistory function is not available.', + }; + } + loadHistory(toolCallData.history); + } + + if (toolCallData.clientHistory) { + await config?.getGeminiClient()?.setHistory(toolCallData.clientHistory); + } + + if (toolCallData.commitHash) { + await gitService?.restoreProjectFromSnapshot(toolCallData.commitHash); + addItem( + { + type: 'info', + text: 'Restored project to the state before the tool call.', + }, + Date.now(), + ); + } + + return { + type: 'tool', + toolName: toolCallData.toolCall.name, + toolArgs: toolCallData.toolCall.args, + }; + } catch (error) { + return { + type: 'message', + messageType: 'error', + content: `Could not read restorable tool calls. This is the error: ${error}`, + }; + } +} + +async function completion( + context: CommandContext, + _partialArg: string, +): Promise<string[]> { + const { services } = context; + const { config } = services; + const checkpointDir = config?.getProjectTempDir() + ? path.join(config.getProjectTempDir(), 'checkpoints') + : undefined; + if (!checkpointDir) { + return []; + } + try { + const files = await fs.readdir(checkpointDir); + return files + .filter((file) => file.endsWith('.json')) + .map((file) => file.replace('.json', '')); + } catch (_err) { + return []; + } +} + +export const restoreCommand = (config: Config | null): SlashCommand | null => { + if (!config?.getCheckpointingEnabled()) { + return null; + } + + return { + name: 'restore', + description: + 'Restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested', + action: restoreAction, + completion, + }; +}; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index d3d5ee8a..51b66fb4 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -41,6 +41,12 @@ export interface CommandContext { * @param item The history item to display as pending, or `null` to clear. */ setPendingItem: (item: HistoryItemWithoutId | null) => void; + /** + * Loads a new set of history items, replacing the current history. + * + * @param history The array of history items to load. + */ + loadHistory: UseHistoryManagerReturn['loadHistory']; }; // Session-specific data session: { |
