diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/config/config.ts | 8 | ||||
| -rw-r--r-- | packages/cli/src/gemini.tsx | 10 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.test.tsx | 3 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.tsx | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 50 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 185 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 35 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 119 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useHistoryManager.ts | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useLogger.ts | 3 | ||||
| -rw-r--r-- | packages/cli/src/utils/cleanup.ts | 18 |
11 files changed, 380 insertions, 63 deletions
diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index d1e7ea0c..1c8ef625 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -43,6 +43,7 @@ interface CliArgs { show_memory_usage: boolean | undefined; yolo: boolean | undefined; telemetry: boolean | undefined; + checkpoint: boolean | undefined; } async function parseArguments(): Promise<CliArgs> { @@ -91,6 +92,12 @@ async function parseArguments(): Promise<CliArgs> { type: 'boolean', description: 'Enable telemetry?', }) + .option('checkpoint', { + alias: 'c', + type: 'boolean', + description: 'Enables checkpointing of file edits', + default: false, + }) .version(process.env.CLI_VERSION || '0.0.0') // This will enable the --version flag based on package.json .help() .alias('h', 'help') @@ -178,6 +185,7 @@ export async function loadCliConfig( fileFilteringAllowBuildArtifacts: settings.fileFiltering?.allowBuildArtifacts, enableModifyWithExternalEditors: settings.enableModifyWithExternalEditors, + checkpoint: argv.checkpoint, }); } diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index eb4f6bb6..555a7c11 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -17,6 +17,7 @@ import { getStartupWarnings } from './utils/startupWarnings.js'; import { runNonInteractive } from './nonInteractiveCli.js'; import { loadGeminiIgnorePatterns } from './utils/loadIgnorePatterns.js'; import { loadExtensions, ExtensionConfig } from './config/extension.js'; +import { cleanupCheckpoints } from './utils/cleanup.js'; import { ApprovalMode, Config, @@ -40,7 +41,7 @@ export async function main() { setWindowTitle(basename(workspaceRoot), settings); const geminiIgnorePatterns = loadGeminiIgnorePatterns(workspaceRoot); - + await cleanupCheckpoints(); if (settings.errors.length > 0) { for (const error of settings.errors) { let errorMessage = `Error in ${error.path}: ${error.message}`; @@ -63,6 +64,13 @@ export async function main() { // Initialize centralized FileDiscoveryService await config.getFileService(); + if (config.getCheckpointEnabled()) { + try { + await config.getGitService(); + } catch { + // For now swallow the error, later log it. + } + } if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index fefb2fe2..bfd2efaf 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -63,6 +63,7 @@ interface MockServerConfig { getVertexAI: Mock<() => boolean | undefined>; getShowMemoryUsage: Mock<() => boolean>; getAccessibility: Mock<() => AccessibilitySettings>; + getProjectRoot: Mock<() => string | undefined>; } // Mock @gemini-cli/core and its Config class @@ -120,7 +121,9 @@ vi.mock('@gemini-cli/core', async (importOriginal) => { getVertexAI: vi.fn(() => opts.vertexai), getShowMemoryUsage: vi.fn(() => opts.showMemoryUsage ?? false), getAccessibility: vi.fn(() => opts.accessibility ?? {}), + getProjectRoot: vi.fn(() => opts.projectRoot), getGeminiClient: vi.fn(() => ({})), + getCheckpointEnabled: vi.fn(() => opts.checkpoint ?? true), }; }); return { diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index bf8c2abb..cdec11e2 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -66,7 +66,7 @@ export const AppWrapper = (props: AppProps) => ( ); const App = ({ config, settings, startupWarnings = [] }: AppProps) => { - const { history, addItem, clearItems } = useHistory(); + const { history, addItem, clearItems, loadHistory } = useHistory(); const { consoleMessages, handleNewMessage, @@ -151,8 +151,10 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const { handleSlashCommand, slashCommands } = useSlashCommandProcessor( config, + history, addItem, clearItems, + loadHistory, refreshStatic, setShowHelp, setDebugMessage, @@ -217,6 +219,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const { streamingState, submitQuery, initError, pendingHistoryItems } = useGeminiStream( config.getGeminiClient(), + history, addItem, setShowHelp, config, @@ -512,7 +515,6 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { )} </Box> )} - <Footer model={config.getModel()} targetDir={config.getTargetDir()} diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 6ec356aa..f16d3239 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -65,6 +65,14 @@ import { } from '@gemini-cli/core'; import { useSessionStats } from '../contexts/SessionContext.js'; +vi.mock('@gemini-code/core', async (importOriginal) => { + const actual = await importOriginal<typeof import('@gemini-code/core')>(); + return { + ...actual, + GitService: vi.fn(), + }; +}); + import * as ShowMemoryCommandModule from './useShowMemoryCommand.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; @@ -84,6 +92,7 @@ vi.mock('open', () => ({ describe('useSlashCommandProcessor', () => { let mockAddItem: ReturnType<typeof vi.fn>; let mockClearItems: ReturnType<typeof vi.fn>; + let mockLoadHistory: ReturnType<typeof vi.fn>; let mockRefreshStatic: ReturnType<typeof vi.fn>; let mockSetShowHelp: ReturnType<typeof vi.fn>; let mockOnDebugMessage: ReturnType<typeof vi.fn>; @@ -96,6 +105,7 @@ describe('useSlashCommandProcessor', () => { beforeEach(() => { mockAddItem = vi.fn(); mockClearItems = vi.fn(); + mockLoadHistory = vi.fn(); mockRefreshStatic = vi.fn(); mockSetShowHelp = vi.fn(); mockOnDebugMessage = vi.fn(); @@ -105,6 +115,8 @@ describe('useSlashCommandProcessor', () => { getDebugMode: vi.fn(() => false), getSandbox: vi.fn(() => 'test-sandbox'), getModel: vi.fn(() => 'test-model'), + getProjectRoot: vi.fn(() => '/test/dir'), + getCheckpointEnabled: vi.fn(() => true), } as unknown as Config; mockCorgiMode = vi.fn(); mockUseSessionStats.mockReturnValue({ @@ -133,8 +145,10 @@ describe('useSlashCommandProcessor', () => { const { result } = renderHook(() => useSlashCommandProcessor( mockConfig, + [], mockAddItem, mockClearItems, + mockLoadHistory, mockRefreshStatic, mockSetShowHelp, mockOnDebugMessage, @@ -153,7 +167,7 @@ describe('useSlashCommandProcessor', () => { const fact = 'Remember this fact'; let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand(`/memory add ${fact}`); + commandResult = await handleSlashCommand(`/memory add ${fact}`); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -187,7 +201,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory add '); + commandResult = await handleSlashCommand('/memory add '); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -211,7 +225,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory show'); + commandResult = await handleSlashCommand('/memory show'); }); expect( ShowMemoryCommandModule.createShowMemoryAction, @@ -226,7 +240,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory refresh'); + commandResult = await handleSlashCommand('/memory refresh'); }); expect(mockPerformMemoryRefresh).toHaveBeenCalled(); expect(commandResult).toBe(true); @@ -238,7 +252,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory foobar'); + commandResult = await handleSlashCommand('/memory foobar'); }); expect(mockAddItem).toHaveBeenNthCalledWith( 2, @@ -300,7 +314,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/help'); + commandResult = await handleSlashCommand('/help'); }); expect(mockSetShowHelp).toHaveBeenCalledWith(true); expect(commandResult).toBe(true); @@ -373,7 +387,7 @@ Add any other context about the problem here. ); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand(`/bug ${bugDescription}`); + commandResult = await handleSlashCommand(`/bug ${bugDescription}`); }); expect(mockAddItem).toHaveBeenCalledTimes(2); @@ -387,7 +401,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/unknowncommand'); + commandResult = await handleSlashCommand('/unknowncommand'); }); expect(mockAddItem).toHaveBeenNthCalledWith( 2, @@ -410,7 +424,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -434,7 +448,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -467,7 +481,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); // Should only show tool1 and tool2, not the MCP tools @@ -499,7 +513,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -545,7 +559,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -571,7 +585,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -633,7 +647,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -706,7 +720,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(true); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -780,7 +794,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -846,7 +860,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); const message = mockAddItem.mock.calls[1][0].text; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 69fb6d06..3699b4e9 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -11,14 +11,22 @@ import process from 'node:process'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { Config, + GitService, Logger, MCPDiscoveryState, MCPServerStatus, getMCPDiscoveryState, getMCPServerStatus, } from '@gemini-cli/core'; -import { Message, MessageType, HistoryItemWithoutId } from '../types.js'; import { useSessionStats } from '../contexts/SessionContext.js'; +import { + Message, + MessageType, + HistoryItemWithoutId, + HistoryItem, +} from '../types.js'; +import { promises as fs } from 'fs'; +import path from 'path'; import { createShowMemoryAction } from './useShowMemoryCommand.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { formatDuration, formatMemoryUsage } from '../utils/formatters.js'; @@ -39,7 +47,10 @@ export interface SlashCommand { mainCommand: string, subCommand?: string, args?: string, - ) => void | SlashCommandActionReturn; // Action can now return this object + ) => + | void + | SlashCommandActionReturn + | Promise<void | SlashCommandActionReturn>; // Action can now return this object } /** @@ -47,8 +58,10 @@ export interface SlashCommand { */ export const useSlashCommandProcessor = ( config: Config | null, + history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], clearItems: UseHistoryManagerReturn['clearItems'], + loadHistory: UseHistoryManagerReturn['loadHistory'], refreshStatic: () => void, setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, onDebugMessage: (message: string) => void, @@ -58,6 +71,13 @@ export const useSlashCommandProcessor = ( showToolDescriptions: boolean = false, ) => { const session = useSessionStats(); + const gitService = useMemo(() => { + if (!config?.getProjectRoot()) { + return; + } + return new GitService(config.getProjectRoot()); + }, [config]); + const addMessage = useCallback( (message: Message) => { // Convert Message to HistoryItemWithoutId @@ -126,8 +146,8 @@ export const useSlashCommandProcessor = ( [addMessage], ); - const slashCommands: SlashCommand[] = useMemo( - () => [ + const slashCommands: SlashCommand[] = useMemo(() => { + const commands: SlashCommand[] = [ { name: 'help', altName: '?', @@ -408,7 +428,9 @@ export const useSlashCommandProcessor = ( if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { sandboxEnv = process.env.SANDBOX; } else if (process.env.SANDBOX === 'sandbox-exec') { - sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`; + sandboxEnv = `sandbox-exec (${ + process.env.SEATBELT_PROFILE || 'unknown' + })`; } const modelVersion = config?.getModel() || 'Unknown'; const cliVersion = getCliVersion(); @@ -437,7 +459,9 @@ export const useSlashCommandProcessor = ( if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { sandboxEnv = process.env.SANDBOX.replace(/^gemini-(?:code-)?/, ''); } else if (process.env.SANDBOX === 'sandbox-exec') { - sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`; + sandboxEnv = `sandbox-exec (${ + process.env.SEATBELT_PROFILE || 'unknown' + })`; } const modelVersion = config?.getModel() || 'Unknown'; const memoryUsage = formatMemoryUsage(process.memoryUsage().rss); @@ -569,31 +593,140 @@ Add any other context about the problem here. name: 'quit', altName: 'exit', description: 'exit the cli', - action: (_mainCommand, _subCommand, _args) => { + action: async (_mainCommand, _subCommand, _args) => { onDebugMessage('Quitting. Good-bye.'); process.exit(0); }, }, - ], - [ - onDebugMessage, - setShowHelp, - refreshStatic, - openThemeDialog, - clearItems, - performMemoryRefresh, - showMemoryAction, - addMemoryAction, - addMessage, - toggleCorgiMode, - config, - showToolDescriptions, - session, - ], - ); + ]; + + if (config?.getCheckpointEnabled()) { + commands.push({ + name: 'restore', + description: + 'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested', + action: async (_mainCommand, subCommand, _args) => { + const checkpointDir = config?.getGeminiDir() + ? path.join(config.getGeminiDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + addMessage({ + type: MessageType.ERROR, + content: 'Could not determine the .gemini directory path.', + timestamp: new Date(), + }); + return; + } + + try { + // Ensure the directory exists before trying to read it. + await fs.mkdir(checkpointDir, { recursive: true }); + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + if (!subCommand) { + if (jsonFiles.length === 0) { + addMessage({ + type: MessageType.INFO, + content: 'No restorable tool calls found.', + timestamp: new Date(), + }); + return; + } + const truncatedFiles = jsonFiles.map((file) => { + const components = file.split('.'); + if (components.length <= 1) { + return file; + } + components.pop(); + return components.join('.'); + }); + const fileList = truncatedFiles.join('\n'); + addMessage({ + type: MessageType.INFO, + content: `Available tool calls to restore:\n\n${fileList}`, + timestamp: new Date(), + }); + return; + } + + const selectedFile = subCommand.endsWith('.json') + ? subCommand + : `${subCommand}.json`; + + if (!jsonFiles.includes(selectedFile)) { + addMessage({ + type: MessageType.ERROR, + content: `File not found: ${selectedFile}`, + timestamp: new Date(), + }); + return; + } + + const filePath = path.join(checkpointDir, selectedFile); + const data = await fs.readFile(filePath, 'utf-8'); + const toolCallData = JSON.parse(data); + + if (toolCallData.history) { + loadHistory(toolCallData.history); + } + + if (toolCallData.clientHistory) { + await config + ?.getGeminiClient() + ?.setHistory(toolCallData.clientHistory); + } + + if (toolCallData.commitHash) { + await gitService?.restoreProjectFromSnapshot( + toolCallData.commitHash, + ); + addMessage({ + type: MessageType.INFO, + content: `Restored project to the state before the tool call.`, + timestamp: new Date(), + }); + } + + return { + shouldScheduleTool: true, + toolName: toolCallData.toolCall.name, + toolArgs: toolCallData.toolCall.args, + }; + } catch (error) { + addMessage({ + type: MessageType.ERROR, + content: `Could not read restorable tool calls. This is the error: ${error}`, + timestamp: new Date(), + }); + } + }, + }); + } + return commands; + }, [ + onDebugMessage, + setShowHelp, + refreshStatic, + openThemeDialog, + clearItems, + performMemoryRefresh, + showMemoryAction, + addMemoryAction, + addMessage, + toggleCorgiMode, + config, + showToolDescriptions, + session, + gitService, + loadHistory, + ]); const handleSlashCommand = useCallback( - (rawQuery: PartListUnion): SlashCommandActionReturn | boolean => { + async ( + rawQuery: PartListUnion, + ): Promise<SlashCommandActionReturn | boolean> => { if (typeof rawQuery !== 'string') { return false; } @@ -625,7 +758,7 @@ Add any other context about the problem here. for (const cmd of slashCommands) { if (mainCommand === cmd.name || mainCommand === cmd.altName) { - const actionResult = cmd.action(mainCommand, subCommand, args); + const actionResult = await cmd.action(mainCommand, subCommand, args); if ( typeof actionResult === 'object' && actionResult?.shouldScheduleTool diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index e39feb01..81c7f52b 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -18,6 +18,7 @@ import { import { Config } from '@gemini-cli/core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { HistoryItem } from '../types.js'; import { Dispatch, SetStateAction } from 'react'; // --- MOCKS --- @@ -38,9 +39,9 @@ const MockedGeminiClientClass = vi.hoisted(() => vi.mock('@gemini-cli/core', async (importOriginal) => { const actualCoreModule = (await importOriginal()) as any; return { - ...(actualCoreModule || {}), - GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses - Config: actualCoreModule.Config, // Ensure Config is passed through + ...actualCoreModule, + GitService: vi.fn(), + GeminiClient: MockedGeminiClientClass, }; }); @@ -277,11 +278,13 @@ describe('useGeminiStream', () => { getToolRegistry: vi.fn( () => ({ getToolSchemaList: vi.fn(() => []) }) as any, ), + getProjectRoot: vi.fn(() => '/test/dir'), + getCheckpointEnabled: vi.fn(() => false), getGeminiClient: mockGetGeminiClient, addHistory: vi.fn(), } as unknown as Config; mockOnDebugMessage = vi.fn(); - mockHandleSlashCommand = vi.fn().mockReturnValue(false); + mockHandleSlashCommand = vi.fn().mockResolvedValue(false); // Mock return value for useReactToolScheduler mockScheduleToolCalls = vi.fn(); @@ -322,19 +325,22 @@ describe('useGeminiStream', () => { const { result, rerender } = renderHook( (props: { client: any; + history: HistoryItem[]; addItem: UseHistoryManagerReturn['addItem']; setShowHelp: Dispatch<SetStateAction<boolean>>; config: Config; onDebugMessage: (message: string) => void; handleSlashCommand: ( - command: PartListUnion, - ) => + cmd: PartListUnion, + ) => Promise< | import('./slashCommandProcessor.js').SlashCommandActionReturn - | boolean; + | boolean + >; shellModeActive: boolean; }) => useGeminiStream( props.client, + props.history, props.addItem, props.setShowHelp, props.config, @@ -345,12 +351,17 @@ describe('useGeminiStream', () => { { initialProps: { client, + history: [], addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, - handleSlashCommand: - mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand, + handleSlashCommand: mockHandleSlashCommand as unknown as ( + cmd: PartListUnion, + ) => Promise< + | import('./slashCommandProcessor.js').SlashCommandActionReturn + | boolean + >, shellModeActive: false, }, }, @@ -467,7 +478,8 @@ describe('useGeminiStream', () => { act(() => { rerender({ client, - addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + history: [], + addItem: mockAddItem, setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, @@ -521,7 +533,8 @@ describe('useGeminiStream', () => { act(() => { rerender({ client, - addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + history: [], + addItem: mockAddItem, setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 725d8737..7d0fe375 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -7,6 +7,7 @@ import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useInput } from 'ink'; import { + Config, GeminiClient, GeminiEventType as ServerGeminiEventType, ServerGeminiStreamEvent as GeminiEvent, @@ -14,14 +15,15 @@ import { ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, - Config, MessageSenderType, ToolCallRequestInfo, logUserPrompt, + GitService, } from '@gemini-cli/core'; import { type Part, type PartListUnion } from '@google/genai'; import { StreamingState, + HistoryItem, HistoryItemWithoutId, HistoryItemToolGroup, MessageType, @@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { promises as fs } from 'fs'; +import path from 'path'; import { useReactToolScheduler, mapToDisplay as mapTrackedToolCallsToDisplay, @@ -68,13 +72,16 @@ enum StreamProcessingStatus { */ export const useGeminiStream = ( geminiClient: GeminiClient | null, + history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, config: Config, onDebugMessage: (message: string) => void, handleSlashCommand: ( cmd: PartListUnion, - ) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean, + ) => Promise< + import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean + >, shellModeActive: boolean, ) => { const [initError, setInitError] = useState<string | null>(null); @@ -84,6 +91,12 @@ export const useGeminiStream = ( useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); const { startNewTurn, addUsage } = useSessionStats(); + const gitService = useMemo(() => { + if (!config.getProjectRoot()) { + return; + } + return new GitService(config.getProjectRoot()); + }, [config]); const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = useReactToolScheduler( @@ -178,7 +191,7 @@ export const useGeminiStream = ( await logger?.logMessage(MessageSenderType.USER, trimmedQuery); // Handle UI-only commands first - const slashCommandResult = handleSlashCommand(trimmedQuery); + const slashCommandResult = await handleSlashCommand(trimmedQuery); if (typeof slashCommandResult === 'boolean' && slashCommandResult) { // Command was handled, and it doesn't require a tool call from here return { queryToSend: null, shouldProceed: false }; @@ -605,6 +618,106 @@ export const useGeminiStream = ( pendingToolCallGroupDisplay, ].filter((i) => i !== undefined && i !== null); + useEffect(() => { + const saveRestorableToolCalls = async () => { + if (!config.getCheckpointEnabled()) { + return; + } + const restorableToolCalls = toolCalls.filter( + (toolCall) => + (toolCall.request.name === 'replace' || + toolCall.request.name === 'write_file') && + toolCall.status === 'awaiting_approval', + ); + + if (restorableToolCalls.length > 0) { + const checkpointDir = config.getGeminiDir() + ? path.join(config.getGeminiDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + return; + } + + try { + await fs.mkdir(checkpointDir, { recursive: true }); + } catch (error) { + if (!isNodeError(error) || error.code !== 'EEXIST') { + onDebugMessage( + `Failed to create checkpoint directory: ${getErrorMessage(error)}`, + ); + return; + } + } + + for (const toolCall of restorableToolCalls) { + const filePath = toolCall.request.args['file_path'] as string; + if (!filePath) { + onDebugMessage( + `Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`, + ); + continue; + } + + try { + let commitHash = await gitService?.createFileSnapshot( + `Snapshot for ${toolCall.request.name}`, + ); + + if (!commitHash) { + commitHash = await gitService?.getCurrentCommitHash(); + } + + if (!commitHash) { + onDebugMessage( + `Failed to create snapshot for ${filePath}. Skipping restorable tool call.`, + ); + continue; + } + + const timestamp = new Date() + .toISOString() + .replace(/:/g, '-') + .replace(/\./g, '_'); + const toolName = toolCall.request.name; + const fileName = path.basename(filePath); + const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`; + const clientHistory = await geminiClient?.getHistory(); + const toolCallWithSnapshotFilePath = path.join( + checkpointDir, + toolCallWithSnapshotFileName, + ); + + await fs.writeFile( + toolCallWithSnapshotFilePath, + JSON.stringify( + { + history, + clientHistory, + toolCall: { + name: toolCall.request.name, + args: toolCall.request.args, + }, + commitHash, + filePath, + }, + null, + 2, + ), + ); + } catch (error) { + onDebugMessage( + `Failed to write restorable tool call file: ${getErrorMessage( + error, + )}`, + ); + } + } + } + }; + saveRestorableToolCalls(); + }, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]); + return { streamingState, submitQuery, diff --git a/packages/cli/src/ui/hooks/useHistoryManager.ts b/packages/cli/src/ui/hooks/useHistoryManager.ts index f82707ef..c45ac8d3 100644 --- a/packages/cli/src/ui/hooks/useHistoryManager.ts +++ b/packages/cli/src/ui/hooks/useHistoryManager.ts @@ -20,6 +20,7 @@ export interface UseHistoryManagerReturn { updates: Partial<Omit<HistoryItem, 'id'>> | HistoryItemUpdater, ) => void; clearItems: () => void; + loadHistory: (newHistory: HistoryItem[]) => void; } /** @@ -38,6 +39,10 @@ export function useHistory(): UseHistoryManagerReturn { return baseTimestamp + messageIdCounterRef.current; }, []); + const loadHistory = useCallback((newHistory: HistoryItem[]) => { + setHistory(newHistory); + }, []); + // Adds a new item to the history state with a unique ID. const addItem = useCallback( (itemData: Omit<HistoryItem, 'id'>, baseTimestamp: number): number => { @@ -101,5 +106,6 @@ export function useHistory(): UseHistoryManagerReturn { addItem, updateItem, clearItems, + loadHistory, }; } diff --git a/packages/cli/src/ui/hooks/useLogger.ts b/packages/cli/src/ui/hooks/useLogger.ts index ea6d6057..eda14187 100644 --- a/packages/cli/src/ui/hooks/useLogger.ts +++ b/packages/cli/src/ui/hooks/useLogger.ts @@ -5,8 +5,7 @@ */ import { useState, useEffect } from 'react'; -import { sessionId } from '@gemini-cli/core'; -import { Logger } from '@gemini-cli/core'; +import { sessionId, Logger } from '@gemini-cli/core'; /** * Hook to manage the logger instance. diff --git a/packages/cli/src/utils/cleanup.ts b/packages/cli/src/utils/cleanup.ts new file mode 100644 index 00000000..1e483373 --- /dev/null +++ b/packages/cli/src/utils/cleanup.ts @@ -0,0 +1,18 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { promises as fs } from 'fs'; +import { join } from 'path'; + +export async function cleanupCheckpoints() { + const geminiDir = join(process.cwd(), '.gemini'); + const checkpointsDir = join(geminiDir, 'checkpoints'); + try { + await fs.rm(checkpointsDir, { recursive: true, force: true }); + } catch { + // Ignore errors if the directory doesn't exist or fails to delete. + } +} |
