diff options
| -rw-r--r-- | package-lock.json | 31 | ||||
| -rw-r--r-- | packages/cli/src/config/config.ts | 8 | ||||
| -rw-r--r-- | packages/cli/src/gemini.tsx | 10 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.test.tsx | 3 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.tsx | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 50 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 185 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 35 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 119 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useHistoryManager.ts | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useLogger.ts | 3 | ||||
| -rw-r--r-- | packages/cli/src/utils/cleanup.ts | 18 | ||||
| -rw-r--r-- | packages/core/package.json | 1 | ||||
| -rw-r--r-- | packages/core/src/config/config.ts | 25 | ||||
| -rw-r--r-- | packages/core/src/core/client.ts | 10 | ||||
| -rw-r--r-- | packages/core/src/core/geminiChat.ts | 3 | ||||
| -rw-r--r-- | packages/core/src/index.ts | 1 | ||||
| -rw-r--r-- | packages/core/src/services/gitService.test.ts | 254 | ||||
| -rw-r--r-- | packages/core/src/services/gitService.ts | 132 |
19 files changed, 837 insertions, 63 deletions
diff --git a/package-lock.json b/package-lock.json index cefcb757..2f55b474 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1166,6 +1166,21 @@ "tslib": "2" } }, + "node_modules/@kwsites/file-exists": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@kwsites/file-exists/-/file-exists-1.1.1.tgz", + "integrity": "sha512-m9/5YGR18lIwxSFDwfE3oA7bWuq9kdau6ugN4H2rJeyhFQZcG9AgSHkQtSD15a8WvTgfz9aikZMrKPHvbpqFiw==", + "license": "MIT", + "dependencies": { + "debug": "^4.1.1" + } + }, + "node_modules/@kwsites/promise-deferred": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz", + "integrity": "sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw==", + "license": "MIT" + }, "node_modules/@modelcontextprotocol/sdk": { "version": "1.12.1", "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.12.1.tgz", @@ -8522,6 +8537,21 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/simple-git": { + "version": "3.28.0", + "resolved": "https://registry.npmjs.org/simple-git/-/simple-git-3.28.0.tgz", + "integrity": "sha512-Rs/vQRwsn1ILH1oBUy8NucJlXmnnLeLCfcvbSehkPzbv3wwoFWIdtfd6Ndo6ZPhlPsCZ60CPI4rxurnwAa+a2w==", + "license": "MIT", + "dependencies": { + "@kwsites/file-exists": "^1.1.1", + "@kwsites/promise-deferred": "^1.1.1", + "debug": "^4.4.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/steveukx/git-js?sponsor=1" + } + }, "node_modules/slice-ansi": { "version": "7.1.0", "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-7.1.0.tgz", @@ -10765,6 +10795,7 @@ "ignore": "^7.0.0", "open": "^10.1.2", "shell-quote": "^1.8.2", + "simple-git": "^3.27.0", "strip-ansi": "^7.1.0", "undici": "^7.10.0" }, diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index d1e7ea0c..1c8ef625 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -43,6 +43,7 @@ interface CliArgs { show_memory_usage: boolean | undefined; yolo: boolean | undefined; telemetry: boolean | undefined; + checkpoint: boolean | undefined; } async function parseArguments(): Promise<CliArgs> { @@ -91,6 +92,12 @@ async function parseArguments(): Promise<CliArgs> { type: 'boolean', description: 'Enable telemetry?', }) + .option('checkpoint', { + alias: 'c', + type: 'boolean', + description: 'Enables checkpointing of file edits', + default: false, + }) .version(process.env.CLI_VERSION || '0.0.0') // This will enable the --version flag based on package.json .help() .alias('h', 'help') @@ -178,6 +185,7 @@ export async function loadCliConfig( fileFilteringAllowBuildArtifacts: settings.fileFiltering?.allowBuildArtifacts, enableModifyWithExternalEditors: settings.enableModifyWithExternalEditors, + checkpoint: argv.checkpoint, }); } diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index eb4f6bb6..555a7c11 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -17,6 +17,7 @@ import { getStartupWarnings } from './utils/startupWarnings.js'; import { runNonInteractive } from './nonInteractiveCli.js'; import { loadGeminiIgnorePatterns } from './utils/loadIgnorePatterns.js'; import { loadExtensions, ExtensionConfig } from './config/extension.js'; +import { cleanupCheckpoints } from './utils/cleanup.js'; import { ApprovalMode, Config, @@ -40,7 +41,7 @@ export async function main() { setWindowTitle(basename(workspaceRoot), settings); const geminiIgnorePatterns = loadGeminiIgnorePatterns(workspaceRoot); - + await cleanupCheckpoints(); if (settings.errors.length > 0) { for (const error of settings.errors) { let errorMessage = `Error in ${error.path}: ${error.message}`; @@ -63,6 +64,13 @@ export async function main() { // Initialize centralized FileDiscoveryService await config.getFileService(); + if (config.getCheckpointEnabled()) { + try { + await config.getGitService(); + } catch { + // For now swallow the error, later log it. + } + } if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index fefb2fe2..bfd2efaf 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -63,6 +63,7 @@ interface MockServerConfig { getVertexAI: Mock<() => boolean | undefined>; getShowMemoryUsage: Mock<() => boolean>; getAccessibility: Mock<() => AccessibilitySettings>; + getProjectRoot: Mock<() => string | undefined>; } // Mock @gemini-cli/core and its Config class @@ -120,7 +121,9 @@ vi.mock('@gemini-cli/core', async (importOriginal) => { getVertexAI: vi.fn(() => opts.vertexai), getShowMemoryUsage: vi.fn(() => opts.showMemoryUsage ?? false), getAccessibility: vi.fn(() => opts.accessibility ?? {}), + getProjectRoot: vi.fn(() => opts.projectRoot), getGeminiClient: vi.fn(() => ({})), + getCheckpointEnabled: vi.fn(() => opts.checkpoint ?? true), }; }); return { diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index bf8c2abb..cdec11e2 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -66,7 +66,7 @@ export const AppWrapper = (props: AppProps) => ( ); const App = ({ config, settings, startupWarnings = [] }: AppProps) => { - const { history, addItem, clearItems } = useHistory(); + const { history, addItem, clearItems, loadHistory } = useHistory(); const { consoleMessages, handleNewMessage, @@ -151,8 +151,10 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const { handleSlashCommand, slashCommands } = useSlashCommandProcessor( config, + history, addItem, clearItems, + loadHistory, refreshStatic, setShowHelp, setDebugMessage, @@ -217,6 +219,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const { streamingState, submitQuery, initError, pendingHistoryItems } = useGeminiStream( config.getGeminiClient(), + history, addItem, setShowHelp, config, @@ -512,7 +515,6 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { )} </Box> )} - <Footer model={config.getModel()} targetDir={config.getTargetDir()} diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 6ec356aa..f16d3239 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -65,6 +65,14 @@ import { } from '@gemini-cli/core'; import { useSessionStats } from '../contexts/SessionContext.js'; +vi.mock('@gemini-code/core', async (importOriginal) => { + const actual = await importOriginal<typeof import('@gemini-code/core')>(); + return { + ...actual, + GitService: vi.fn(), + }; +}); + import * as ShowMemoryCommandModule from './useShowMemoryCommand.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; @@ -84,6 +92,7 @@ vi.mock('open', () => ({ describe('useSlashCommandProcessor', () => { let mockAddItem: ReturnType<typeof vi.fn>; let mockClearItems: ReturnType<typeof vi.fn>; + let mockLoadHistory: ReturnType<typeof vi.fn>; let mockRefreshStatic: ReturnType<typeof vi.fn>; let mockSetShowHelp: ReturnType<typeof vi.fn>; let mockOnDebugMessage: ReturnType<typeof vi.fn>; @@ -96,6 +105,7 @@ describe('useSlashCommandProcessor', () => { beforeEach(() => { mockAddItem = vi.fn(); mockClearItems = vi.fn(); + mockLoadHistory = vi.fn(); mockRefreshStatic = vi.fn(); mockSetShowHelp = vi.fn(); mockOnDebugMessage = vi.fn(); @@ -105,6 +115,8 @@ describe('useSlashCommandProcessor', () => { getDebugMode: vi.fn(() => false), getSandbox: vi.fn(() => 'test-sandbox'), getModel: vi.fn(() => 'test-model'), + getProjectRoot: vi.fn(() => '/test/dir'), + getCheckpointEnabled: vi.fn(() => true), } as unknown as Config; mockCorgiMode = vi.fn(); mockUseSessionStats.mockReturnValue({ @@ -133,8 +145,10 @@ describe('useSlashCommandProcessor', () => { const { result } = renderHook(() => useSlashCommandProcessor( mockConfig, + [], mockAddItem, mockClearItems, + mockLoadHistory, mockRefreshStatic, mockSetShowHelp, mockOnDebugMessage, @@ -153,7 +167,7 @@ describe('useSlashCommandProcessor', () => { const fact = 'Remember this fact'; let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand(`/memory add ${fact}`); + commandResult = await handleSlashCommand(`/memory add ${fact}`); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -187,7 +201,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory add '); + commandResult = await handleSlashCommand('/memory add '); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -211,7 +225,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory show'); + commandResult = await handleSlashCommand('/memory show'); }); expect( ShowMemoryCommandModule.createShowMemoryAction, @@ -226,7 +240,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory refresh'); + commandResult = await handleSlashCommand('/memory refresh'); }); expect(mockPerformMemoryRefresh).toHaveBeenCalled(); expect(commandResult).toBe(true); @@ -238,7 +252,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/memory foobar'); + commandResult = await handleSlashCommand('/memory foobar'); }); expect(mockAddItem).toHaveBeenNthCalledWith( 2, @@ -300,7 +314,7 @@ describe('useSlashCommandProcessor', () => { const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/help'); + commandResult = await handleSlashCommand('/help'); }); expect(mockSetShowHelp).toHaveBeenCalledWith(true); expect(commandResult).toBe(true); @@ -373,7 +387,7 @@ Add any other context about the problem here. ); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand(`/bug ${bugDescription}`); + commandResult = await handleSlashCommand(`/bug ${bugDescription}`); }); expect(mockAddItem).toHaveBeenCalledTimes(2); @@ -387,7 +401,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/unknowncommand'); + commandResult = await handleSlashCommand('/unknowncommand'); }); expect(mockAddItem).toHaveBeenNthCalledWith( 2, @@ -410,7 +424,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -434,7 +448,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -467,7 +481,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); // Should only show tool1 and tool2, not the MCP tools @@ -499,7 +513,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/tools'); + commandResult = await handleSlashCommand('/tools'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -545,7 +559,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -571,7 +585,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -633,7 +647,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -706,7 +720,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(true); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -780,7 +794,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); expect(mockAddItem).toHaveBeenNthCalledWith( @@ -846,7 +860,7 @@ Add any other context about the problem here. const { handleSlashCommand } = getProcessor(); let commandResult: SlashCommandActionReturn | boolean = false; await act(async () => { - commandResult = handleSlashCommand('/mcp'); + commandResult = await handleSlashCommand('/mcp'); }); const message = mockAddItem.mock.calls[1][0].text; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 69fb6d06..3699b4e9 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -11,14 +11,22 @@ import process from 'node:process'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { Config, + GitService, Logger, MCPDiscoveryState, MCPServerStatus, getMCPDiscoveryState, getMCPServerStatus, } from '@gemini-cli/core'; -import { Message, MessageType, HistoryItemWithoutId } from '../types.js'; import { useSessionStats } from '../contexts/SessionContext.js'; +import { + Message, + MessageType, + HistoryItemWithoutId, + HistoryItem, +} from '../types.js'; +import { promises as fs } from 'fs'; +import path from 'path'; import { createShowMemoryAction } from './useShowMemoryCommand.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { formatDuration, formatMemoryUsage } from '../utils/formatters.js'; @@ -39,7 +47,10 @@ export interface SlashCommand { mainCommand: string, subCommand?: string, args?: string, - ) => void | SlashCommandActionReturn; // Action can now return this object + ) => + | void + | SlashCommandActionReturn + | Promise<void | SlashCommandActionReturn>; // Action can now return this object } /** @@ -47,8 +58,10 @@ export interface SlashCommand { */ export const useSlashCommandProcessor = ( config: Config | null, + history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], clearItems: UseHistoryManagerReturn['clearItems'], + loadHistory: UseHistoryManagerReturn['loadHistory'], refreshStatic: () => void, setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, onDebugMessage: (message: string) => void, @@ -58,6 +71,13 @@ export const useSlashCommandProcessor = ( showToolDescriptions: boolean = false, ) => { const session = useSessionStats(); + const gitService = useMemo(() => { + if (!config?.getProjectRoot()) { + return; + } + return new GitService(config.getProjectRoot()); + }, [config]); + const addMessage = useCallback( (message: Message) => { // Convert Message to HistoryItemWithoutId @@ -126,8 +146,8 @@ export const useSlashCommandProcessor = ( [addMessage], ); - const slashCommands: SlashCommand[] = useMemo( - () => [ + const slashCommands: SlashCommand[] = useMemo(() => { + const commands: SlashCommand[] = [ { name: 'help', altName: '?', @@ -408,7 +428,9 @@ export const useSlashCommandProcessor = ( if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { sandboxEnv = process.env.SANDBOX; } else if (process.env.SANDBOX === 'sandbox-exec') { - sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`; + sandboxEnv = `sandbox-exec (${ + process.env.SEATBELT_PROFILE || 'unknown' + })`; } const modelVersion = config?.getModel() || 'Unknown'; const cliVersion = getCliVersion(); @@ -437,7 +459,9 @@ export const useSlashCommandProcessor = ( if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { sandboxEnv = process.env.SANDBOX.replace(/^gemini-(?:code-)?/, ''); } else if (process.env.SANDBOX === 'sandbox-exec') { - sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`; + sandboxEnv = `sandbox-exec (${ + process.env.SEATBELT_PROFILE || 'unknown' + })`; } const modelVersion = config?.getModel() || 'Unknown'; const memoryUsage = formatMemoryUsage(process.memoryUsage().rss); @@ -569,31 +593,140 @@ Add any other context about the problem here. name: 'quit', altName: 'exit', description: 'exit the cli', - action: (_mainCommand, _subCommand, _args) => { + action: async (_mainCommand, _subCommand, _args) => { onDebugMessage('Quitting. Good-bye.'); process.exit(0); }, }, - ], - [ - onDebugMessage, - setShowHelp, - refreshStatic, - openThemeDialog, - clearItems, - performMemoryRefresh, - showMemoryAction, - addMemoryAction, - addMessage, - toggleCorgiMode, - config, - showToolDescriptions, - session, - ], - ); + ]; + + if (config?.getCheckpointEnabled()) { + commands.push({ + name: 'restore', + description: + 'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested', + action: async (_mainCommand, subCommand, _args) => { + const checkpointDir = config?.getGeminiDir() + ? path.join(config.getGeminiDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + addMessage({ + type: MessageType.ERROR, + content: 'Could not determine the .gemini directory path.', + timestamp: new Date(), + }); + return; + } + + try { + // Ensure the directory exists before trying to read it. + await fs.mkdir(checkpointDir, { recursive: true }); + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + if (!subCommand) { + if (jsonFiles.length === 0) { + addMessage({ + type: MessageType.INFO, + content: 'No restorable tool calls found.', + timestamp: new Date(), + }); + return; + } + const truncatedFiles = jsonFiles.map((file) => { + const components = file.split('.'); + if (components.length <= 1) { + return file; + } + components.pop(); + return components.join('.'); + }); + const fileList = truncatedFiles.join('\n'); + addMessage({ + type: MessageType.INFO, + content: `Available tool calls to restore:\n\n${fileList}`, + timestamp: new Date(), + }); + return; + } + + const selectedFile = subCommand.endsWith('.json') + ? subCommand + : `${subCommand}.json`; + + if (!jsonFiles.includes(selectedFile)) { + addMessage({ + type: MessageType.ERROR, + content: `File not found: ${selectedFile}`, + timestamp: new Date(), + }); + return; + } + + const filePath = path.join(checkpointDir, selectedFile); + const data = await fs.readFile(filePath, 'utf-8'); + const toolCallData = JSON.parse(data); + + if (toolCallData.history) { + loadHistory(toolCallData.history); + } + + if (toolCallData.clientHistory) { + await config + ?.getGeminiClient() + ?.setHistory(toolCallData.clientHistory); + } + + if (toolCallData.commitHash) { + await gitService?.restoreProjectFromSnapshot( + toolCallData.commitHash, + ); + addMessage({ + type: MessageType.INFO, + content: `Restored project to the state before the tool call.`, + timestamp: new Date(), + }); + } + + return { + shouldScheduleTool: true, + toolName: toolCallData.toolCall.name, + toolArgs: toolCallData.toolCall.args, + }; + } catch (error) { + addMessage({ + type: MessageType.ERROR, + content: `Could not read restorable tool calls. This is the error: ${error}`, + timestamp: new Date(), + }); + } + }, + }); + } + return commands; + }, [ + onDebugMessage, + setShowHelp, + refreshStatic, + openThemeDialog, + clearItems, + performMemoryRefresh, + showMemoryAction, + addMemoryAction, + addMessage, + toggleCorgiMode, + config, + showToolDescriptions, + session, + gitService, + loadHistory, + ]); const handleSlashCommand = useCallback( - (rawQuery: PartListUnion): SlashCommandActionReturn | boolean => { + async ( + rawQuery: PartListUnion, + ): Promise<SlashCommandActionReturn | boolean> => { if (typeof rawQuery !== 'string') { return false; } @@ -625,7 +758,7 @@ Add any other context about the problem here. for (const cmd of slashCommands) { if (mainCommand === cmd.name || mainCommand === cmd.altName) { - const actionResult = cmd.action(mainCommand, subCommand, args); + const actionResult = await cmd.action(mainCommand, subCommand, args); if ( typeof actionResult === 'object' && actionResult?.shouldScheduleTool diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index e39feb01..81c7f52b 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -18,6 +18,7 @@ import { import { Config } from '@gemini-cli/core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { HistoryItem } from '../types.js'; import { Dispatch, SetStateAction } from 'react'; // --- MOCKS --- @@ -38,9 +39,9 @@ const MockedGeminiClientClass = vi.hoisted(() => vi.mock('@gemini-cli/core', async (importOriginal) => { const actualCoreModule = (await importOriginal()) as any; return { - ...(actualCoreModule || {}), - GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses - Config: actualCoreModule.Config, // Ensure Config is passed through + ...actualCoreModule, + GitService: vi.fn(), + GeminiClient: MockedGeminiClientClass, }; }); @@ -277,11 +278,13 @@ describe('useGeminiStream', () => { getToolRegistry: vi.fn( () => ({ getToolSchemaList: vi.fn(() => []) }) as any, ), + getProjectRoot: vi.fn(() => '/test/dir'), + getCheckpointEnabled: vi.fn(() => false), getGeminiClient: mockGetGeminiClient, addHistory: vi.fn(), } as unknown as Config; mockOnDebugMessage = vi.fn(); - mockHandleSlashCommand = vi.fn().mockReturnValue(false); + mockHandleSlashCommand = vi.fn().mockResolvedValue(false); // Mock return value for useReactToolScheduler mockScheduleToolCalls = vi.fn(); @@ -322,19 +325,22 @@ describe('useGeminiStream', () => { const { result, rerender } = renderHook( (props: { client: any; + history: HistoryItem[]; addItem: UseHistoryManagerReturn['addItem']; setShowHelp: Dispatch<SetStateAction<boolean>>; config: Config; onDebugMessage: (message: string) => void; handleSlashCommand: ( - command: PartListUnion, - ) => + cmd: PartListUnion, + ) => Promise< | import('./slashCommandProcessor.js').SlashCommandActionReturn - | boolean; + | boolean + >; shellModeActive: boolean; }) => useGeminiStream( props.client, + props.history, props.addItem, props.setShowHelp, props.config, @@ -345,12 +351,17 @@ describe('useGeminiStream', () => { { initialProps: { client, + history: [], addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, - handleSlashCommand: - mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand, + handleSlashCommand: mockHandleSlashCommand as unknown as ( + cmd: PartListUnion, + ) => Promise< + | import('./slashCommandProcessor.js').SlashCommandActionReturn + | boolean + >, shellModeActive: false, }, }, @@ -467,7 +478,8 @@ describe('useGeminiStream', () => { act(() => { rerender({ client, - addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + history: [], + addItem: mockAddItem, setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, @@ -521,7 +533,8 @@ describe('useGeminiStream', () => { act(() => { rerender({ client, - addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + history: [], + addItem: mockAddItem, setShowHelp: mockSetShowHelp, config: mockConfig, onDebugMessage: mockOnDebugMessage, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 725d8737..7d0fe375 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -7,6 +7,7 @@ import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useInput } from 'ink'; import { + Config, GeminiClient, GeminiEventType as ServerGeminiEventType, ServerGeminiStreamEvent as GeminiEvent, @@ -14,14 +15,15 @@ import { ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, - Config, MessageSenderType, ToolCallRequestInfo, logUserPrompt, + GitService, } from '@gemini-cli/core'; import { type Part, type PartListUnion } from '@google/genai'; import { StreamingState, + HistoryItem, HistoryItemWithoutId, HistoryItemToolGroup, MessageType, @@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { promises as fs } from 'fs'; +import path from 'path'; import { useReactToolScheduler, mapToDisplay as mapTrackedToolCallsToDisplay, @@ -68,13 +72,16 @@ enum StreamProcessingStatus { */ export const useGeminiStream = ( geminiClient: GeminiClient | null, + history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, config: Config, onDebugMessage: (message: string) => void, handleSlashCommand: ( cmd: PartListUnion, - ) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean, + ) => Promise< + import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean + >, shellModeActive: boolean, ) => { const [initError, setInitError] = useState<string | null>(null); @@ -84,6 +91,12 @@ export const useGeminiStream = ( useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); const { startNewTurn, addUsage } = useSessionStats(); + const gitService = useMemo(() => { + if (!config.getProjectRoot()) { + return; + } + return new GitService(config.getProjectRoot()); + }, [config]); const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = useReactToolScheduler( @@ -178,7 +191,7 @@ export const useGeminiStream = ( await logger?.logMessage(MessageSenderType.USER, trimmedQuery); // Handle UI-only commands first - const slashCommandResult = handleSlashCommand(trimmedQuery); + const slashCommandResult = await handleSlashCommand(trimmedQuery); if (typeof slashCommandResult === 'boolean' && slashCommandResult) { // Command was handled, and it doesn't require a tool call from here return { queryToSend: null, shouldProceed: false }; @@ -605,6 +618,106 @@ export const useGeminiStream = ( pendingToolCallGroupDisplay, ].filter((i) => i !== undefined && i !== null); + useEffect(() => { + const saveRestorableToolCalls = async () => { + if (!config.getCheckpointEnabled()) { + return; + } + const restorableToolCalls = toolCalls.filter( + (toolCall) => + (toolCall.request.name === 'replace' || + toolCall.request.name === 'write_file') && + toolCall.status === 'awaiting_approval', + ); + + if (restorableToolCalls.length > 0) { + const checkpointDir = config.getGeminiDir() + ? path.join(config.getGeminiDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + return; + } + + try { + await fs.mkdir(checkpointDir, { recursive: true }); + } catch (error) { + if (!isNodeError(error) || error.code !== 'EEXIST') { + onDebugMessage( + `Failed to create checkpoint directory: ${getErrorMessage(error)}`, + ); + return; + } + } + + for (const toolCall of restorableToolCalls) { + const filePath = toolCall.request.args['file_path'] as string; + if (!filePath) { + onDebugMessage( + `Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`, + ); + continue; + } + + try { + let commitHash = await gitService?.createFileSnapshot( + `Snapshot for ${toolCall.request.name}`, + ); + + if (!commitHash) { + commitHash = await gitService?.getCurrentCommitHash(); + } + + if (!commitHash) { + onDebugMessage( + `Failed to create snapshot for ${filePath}. Skipping restorable tool call.`, + ); + continue; + } + + const timestamp = new Date() + .toISOString() + .replace(/:/g, '-') + .replace(/\./g, '_'); + const toolName = toolCall.request.name; + const fileName = path.basename(filePath); + const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`; + const clientHistory = await geminiClient?.getHistory(); + const toolCallWithSnapshotFilePath = path.join( + checkpointDir, + toolCallWithSnapshotFileName, + ); + + await fs.writeFile( + toolCallWithSnapshotFilePath, + JSON.stringify( + { + history, + clientHistory, + toolCall: { + name: toolCall.request.name, + args: toolCall.request.args, + }, + commitHash, + filePath, + }, + null, + 2, + ), + ); + } catch (error) { + onDebugMessage( + `Failed to write restorable tool call file: ${getErrorMessage( + error, + )}`, + ); + } + } + } + }; + saveRestorableToolCalls(); + }, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]); + return { streamingState, submitQuery, diff --git a/packages/cli/src/ui/hooks/useHistoryManager.ts b/packages/cli/src/ui/hooks/useHistoryManager.ts index f82707ef..c45ac8d3 100644 --- a/packages/cli/src/ui/hooks/useHistoryManager.ts +++ b/packages/cli/src/ui/hooks/useHistoryManager.ts @@ -20,6 +20,7 @@ export interface UseHistoryManagerReturn { updates: Partial<Omit<HistoryItem, 'id'>> | HistoryItemUpdater, ) => void; clearItems: () => void; + loadHistory: (newHistory: HistoryItem[]) => void; } /** @@ -38,6 +39,10 @@ export function useHistory(): UseHistoryManagerReturn { return baseTimestamp + messageIdCounterRef.current; }, []); + const loadHistory = useCallback((newHistory: HistoryItem[]) => { + setHistory(newHistory); + }, []); + // Adds a new item to the history state with a unique ID. const addItem = useCallback( (itemData: Omit<HistoryItem, 'id'>, baseTimestamp: number): number => { @@ -101,5 +106,6 @@ export function useHistory(): UseHistoryManagerReturn { addItem, updateItem, clearItems, + loadHistory, }; } diff --git a/packages/cli/src/ui/hooks/useLogger.ts b/packages/cli/src/ui/hooks/useLogger.ts index ea6d6057..eda14187 100644 --- a/packages/cli/src/ui/hooks/useLogger.ts +++ b/packages/cli/src/ui/hooks/useLogger.ts @@ -5,8 +5,7 @@ */ import { useState, useEffect } from 'react'; -import { sessionId } from '@gemini-cli/core'; -import { Logger } from '@gemini-cli/core'; +import { sessionId, Logger } from '@gemini-cli/core'; /** * Hook to manage the logger instance. diff --git a/packages/cli/src/utils/cleanup.ts b/packages/cli/src/utils/cleanup.ts new file mode 100644 index 00000000..1e483373 --- /dev/null +++ b/packages/cli/src/utils/cleanup.ts @@ -0,0 +1,18 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { promises as fs } from 'fs'; +import { join } from 'path'; + +export async function cleanupCheckpoints() { + const geminiDir = join(process.cwd(), '.gemini'); + const checkpointsDir = join(geminiDir, 'checkpoints'); + try { + await fs.rm(checkpointsDir, { recursive: true, force: true }); + } catch { + // Ignore errors if the directory doesn't exist or fails to delete. + } +} diff --git a/packages/core/package.json b/packages/core/package.json index d4432e9a..8d11bdd5 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -36,6 +36,7 @@ "ignore": "^7.0.0", "open": "^10.1.2", "shell-quote": "^1.8.2", + "simple-git": "^3.27.0", "strip-ansi": "^7.1.0", "undici": "^7.10.0" }, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index d42fbbec..297178fd 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -25,6 +25,7 @@ import { WebSearchTool } from '../tools/web-search.js'; import { GeminiClient } from '../core/client.js'; import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; +import { GitService } from '../services/gitService.js'; import { initializeTelemetry } from '../telemetry/index.js'; export enum ApprovalMode { @@ -80,6 +81,7 @@ export interface ConfigParameters { fileFilteringRespectGitIgnore?: boolean; fileFilteringAllowBuildArtifacts?: boolean; enableModifyWithExternalEditors?: boolean; + checkpoint?: boolean; } export class Config { @@ -111,6 +113,8 @@ export class Config { private readonly fileFilteringAllowBuildArtifacts: boolean; private readonly enableModifyWithExternalEditors: boolean; private fileDiscoveryService: FileDiscoveryService | null = null; + private gitService: GitService | undefined = undefined; + private readonly checkpoint: boolean; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -142,6 +146,7 @@ export class Config { params.fileFilteringAllowBuildArtifacts ?? false; this.enableModifyWithExternalEditors = params.enableModifyWithExternalEditors ?? false; + this.checkpoint = params.checkpoint ?? false; if (params.contextFileName) { setGeminiMdFilename(params.contextFileName); @@ -182,6 +187,10 @@ export class Config { return this.targetDir; } + getProjectRoot(): string { + return this.targetDir; + } + async getToolRegistry(): Promise<ToolRegistry> { return this.toolRegistry; } @@ -265,6 +274,10 @@ export class Config { return this.geminiClient; } + getGeminiDir(): string { + return path.join(this.targetDir, GEMINI_DIR); + } + getGeminiIgnorePatterns(): string[] { return this.geminiIgnorePatterns; } @@ -281,6 +294,10 @@ export class Config { return this.enableModifyWithExternalEditors; } + getCheckpointEnabled(): boolean { + return this.checkpoint; + } + async getFileService(): Promise<FileDiscoveryService> { if (!this.fileDiscoveryService) { this.fileDiscoveryService = new FileDiscoveryService(this.targetDir); @@ -291,6 +308,14 @@ export class Config { } return this.fileDiscoveryService; } + + async getGitService(): Promise<GitService> { + if (!this.gitService) { + this.gitService = new GitService(this.targetDir); + await this.gitService.initialize(); + } + return this.gitService; + } } function findEnvFile(startDir: string): string | null { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 596ddcd7..4e4dc55e 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -77,6 +77,16 @@ export class GeminiClient { return this.chat; } + async getHistory(): Promise<Content[]> { + const chat = await this.chat; + return chat.getHistory(); + } + + async setHistory(history: Content[]): Promise<void> { + const chat = await this.chat; + chat.setHistory(history); + } + private async getEnvironment(): Promise<Part[]> { const cwd = process.cwd(); const today = new Date().toLocaleDateString(undefined, { diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 2a81aca8..d15f9d1a 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -297,6 +297,9 @@ export class GeminiChat { addHistory(content: Content): void { this.history.push(content); } + setHistory(history: Content[]): void { + this.history = history; + } private async *processStreamResponse( streamResponse: AsyncGenerator<GenerateContentResponse>, diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 60959281..09ad1e92 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -29,6 +29,7 @@ export * from './utils/editor.js'; // Export services export * from './services/fileDiscoveryService.js'; +export * from './services/gitService.js'; // Export base tool definitions export * from './tools/tools.js'; diff --git a/packages/core/src/services/gitService.test.ts b/packages/core/src/services/gitService.test.ts new file mode 100644 index 00000000..67c3c091 --- /dev/null +++ b/packages/core/src/services/gitService.test.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { GitService, historyDirName } from './gitService.js'; +import * as path from 'path'; +import type * as FsPromisesModule from 'fs/promises'; +import type { ChildProcess } from 'node:child_process'; + +const hoistedMockExec = vi.hoisted(() => vi.fn()); +vi.mock('node:child_process', () => ({ + exec: hoistedMockExec, +})); + +const hoistedMockMkdir = vi.hoisted(() => vi.fn()); +const hoistedMockReadFile = vi.hoisted(() => vi.fn()); +const hoistedMockWriteFile = vi.hoisted(() => vi.fn()); + +vi.mock('fs/promises', async (importOriginal) => { + const actual = (await importOriginal()) as typeof FsPromisesModule; + return { + ...actual, + mkdir: hoistedMockMkdir, + readFile: hoistedMockReadFile, + writeFile: hoistedMockWriteFile, + }; +}); + +const hoistedMockSimpleGit = vi.hoisted(() => vi.fn()); +const hoistedMockCheckIsRepo = vi.hoisted(() => vi.fn()); +const hoistedMockInit = vi.hoisted(() => vi.fn()); +const hoistedMockRaw = vi.hoisted(() => vi.fn()); +const hoistedMockAdd = vi.hoisted(() => vi.fn()); +const hoistedMockCommit = vi.hoisted(() => vi.fn()); +vi.mock('simple-git', () => ({ + simpleGit: hoistedMockSimpleGit.mockImplementation(() => ({ + checkIsRepo: hoistedMockCheckIsRepo, + init: hoistedMockInit, + raw: hoistedMockRaw, + add: hoistedMockAdd, + commit: hoistedMockCommit, + })), + CheckRepoActions: { IS_REPO_ROOT: 'is-repo-root' }, +})); + +const hoistedIsGitRepositoryMock = vi.hoisted(() => vi.fn()); +vi.mock('../utils/gitUtils.js', () => ({ + isGitRepository: hoistedIsGitRepositoryMock, +})); + +const hoistedMockIsNodeError = vi.hoisted(() => vi.fn()); +vi.mock('../utils/errors.js', () => ({ + isNodeError: hoistedMockIsNodeError, +})); + +describe('GitService', () => { + const mockProjectRoot = '/test/project'; + + beforeEach(() => { + vi.clearAllMocks(); + hoistedIsGitRepositoryMock.mockReturnValue(true); + hoistedMockExec.mockImplementation((command, callback) => { + if (command === 'git --version') { + callback(null, 'git version 2.0.0'); + } else { + callback(new Error('Command not mocked')); + } + return {}; + }); + hoistedMockMkdir.mockResolvedValue(undefined); + hoistedMockReadFile.mockResolvedValue(''); + hoistedMockWriteFile.mockResolvedValue(undefined); + hoistedMockIsNodeError.mockImplementation((e) => e instanceof Error); + + hoistedMockSimpleGit.mockImplementation(() => ({ + checkIsRepo: hoistedMockCheckIsRepo, + init: hoistedMockInit, + raw: hoistedMockRaw, + add: hoistedMockAdd, + commit: hoistedMockCommit, + })); + hoistedMockCheckIsRepo.mockResolvedValue(false); + hoistedMockInit.mockResolvedValue(undefined); + hoistedMockRaw.mockResolvedValue(''); + hoistedMockAdd.mockResolvedValue(undefined); + hoistedMockCommit.mockResolvedValue({ + commit: 'initial', + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('constructor', () => { + it('should successfully create an instance if projectRoot is a Git repository', () => { + expect(() => new GitService(mockProjectRoot)).not.toThrow(); + }); + }); + + describe('verifyGitAvailability', () => { + it('should resolve true if git --version command succeeds', async () => { + const service = new GitService(mockProjectRoot); + await expect(service.verifyGitAvailability()).resolves.toBe(true); + }); + + it('should resolve false if git --version command fails', async () => { + hoistedMockExec.mockImplementation((command, callback) => { + callback(new Error('git not found')); + return {} as ChildProcess; + }); + const service = new GitService(mockProjectRoot); + await expect(service.verifyGitAvailability()).resolves.toBe(false); + }); + }); + + describe('initialize', () => { + it('should throw an error if projectRoot is not a Git repository', async () => { + hoistedIsGitRepositoryMock.mockReturnValue(false); + const service = new GitService(mockProjectRoot); + await expect(service.initialize()).rejects.toThrow( + 'GitService requires a Git repository', + ); + }); + + it('should throw an error if Git is not available', async () => { + hoistedMockExec.mockImplementation((command, callback) => { + callback(new Error('git not found')); + return {} as ChildProcess; + }); + const service = new GitService(mockProjectRoot); + await expect(service.initialize()).rejects.toThrow( + 'GitService requires Git to be installed', + ); + }); + }); + + it('should call setupHiddenGitRepository if Git is available', async () => { + const service = new GitService(mockProjectRoot); + const setupSpy = vi + .spyOn(service, 'setupHiddenGitRepository') + .mockResolvedValue(undefined); + + await service.initialize(); + expect(setupSpy).toHaveBeenCalled(); + }); + + describe('setupHiddenGitRepository', () => { + const historyDir = path.join(mockProjectRoot, historyDirName); + const repoDir = path.join(historyDir, 'repository'); + const hiddenGitIgnorePath = path.join(repoDir, '.gitignore'); + const visibleGitIgnorePath = path.join(mockProjectRoot, '.gitignore'); + + it('should create history and repository directories', async () => { + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockMkdir).toHaveBeenCalledWith(repoDir, { + recursive: true, + }); + }); + + it('should initialize git repo in historyDir if not already initialized', async () => { + hoistedMockCheckIsRepo.mockResolvedValue(false); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockSimpleGit).toHaveBeenCalledWith(repoDir); + expect(hoistedMockInit).toHaveBeenCalled(); + }); + + it('should not initialize git repo if already initialized', async () => { + hoistedMockCheckIsRepo.mockResolvedValue(true); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockInit).not.toHaveBeenCalled(); + }); + + it('should copy .gitignore from projectRoot if it exists', async () => { + const gitignoreContent = `node_modules/\n.env`; + hoistedMockReadFile.mockImplementation(async (filePath) => { + if (filePath === visibleGitIgnorePath) { + return gitignoreContent; + } + return ''; + }); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockReadFile).toHaveBeenCalledWith( + visibleGitIgnorePath, + 'utf-8', + ); + expect(hoistedMockWriteFile).toHaveBeenCalledWith( + hiddenGitIgnorePath, + gitignoreContent, + ); + }); + + it('should throw an error if reading projectRoot .gitignore fails with other errors', async () => { + const readError = new Error('Read permission denied'); + hoistedMockReadFile.mockImplementation(async (filePath) => { + if (filePath === visibleGitIgnorePath) { + throw readError; + } + return ''; + }); + hoistedMockIsNodeError.mockImplementation( + (e: unknown): e is NodeJS.ErrnoException => + e === readError && + e instanceof Error && + (e as NodeJS.ErrnoException).code !== 'ENOENT', + ); + + const service = new GitService(mockProjectRoot); + await expect(service.setupHiddenGitRepository()).rejects.toThrow( + 'Read permission denied', + ); + }); + + it('should add historyDirName to projectRoot .gitignore if not present', async () => { + const initialGitignoreContent = 'node_modules/'; + hoistedMockReadFile.mockImplementation(async (filePath) => { + if (filePath === visibleGitIgnorePath) { + return initialGitignoreContent; + } + return ''; + }); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + const expectedContent = `${initialGitignoreContent}\n# Gemini CLI history directory\n${historyDirName}\n`; + expect(hoistedMockWriteFile).toHaveBeenCalledWith( + visibleGitIgnorePath, + expectedContent, + ); + }); + + it('should make an initial commit if no commits exist in history repo', async () => { + hoistedMockRaw.mockResolvedValue(''); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockAdd).toHaveBeenCalledWith(hiddenGitIgnorePath); + expect(hoistedMockCommit).toHaveBeenCalledWith('Initial commit'); + }); + + it('should not make an initial commit if commits already exist', async () => { + hoistedMockRaw.mockResolvedValue('test-commit'); + const service = new GitService(mockProjectRoot); + await service.setupHiddenGitRepository(); + expect(hoistedMockAdd).not.toHaveBeenCalled(); + expect(hoistedMockCommit).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/core/src/services/gitService.ts b/packages/core/src/services/gitService.ts new file mode 100644 index 00000000..8cd6b887 --- /dev/null +++ b/packages/core/src/services/gitService.ts @@ -0,0 +1,132 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'fs/promises'; +import * as path from 'path'; +import { isNodeError } from '../utils/errors.js'; +import { isGitRepository } from '../utils/gitUtils.js'; +import { exec } from 'node:child_process'; +import { simpleGit, SimpleGit, CheckRepoActions } from 'simple-git'; + +export const historyDirName = '.gemini_cli_history'; + +export class GitService { + private projectRoot: string; + + constructor(projectRoot: string) { + this.projectRoot = path.resolve(projectRoot); + } + + async initialize(): Promise<void> { + if (!isGitRepository(this.projectRoot)) { + throw new Error('GitService requires a Git repository'); + } + const gitAvailable = await this.verifyGitAvailability(); + if (!gitAvailable) { + throw new Error('GitService requires Git to be installed'); + } + this.setupHiddenGitRepository(); + } + + verifyGitAvailability(): Promise<boolean> { + return new Promise((resolve) => { + exec('git --version', (error) => { + if (error) { + resolve(false); + } else { + resolve(true); + } + }); + }); + } + + /** + * Creates a hidden git repository in the project root. + * The Git repository is used to support checkpointing. + */ + async setupHiddenGitRepository() { + const historyDir = path.join(this.projectRoot, historyDirName); + const repoDir = path.join(historyDir, 'repository'); + + await fs.mkdir(repoDir, { recursive: true }); + const repoInstance: SimpleGit = simpleGit(repoDir); + const isRepoDefined = await repoInstance.checkIsRepo( + CheckRepoActions.IS_REPO_ROOT, + ); + if (!isRepoDefined) { + await repoInstance.init(); + try { + await repoInstance.raw([ + 'worktree', + 'add', + this.projectRoot, + '--force', + ]); + } catch (error) { + console.log('Failed to add worktree:', error); + } + } + + const visibileGitIgnorePath = path.join(this.projectRoot, '.gitignore'); + const hiddenGitIgnorePath = path.join(repoDir, '.gitignore'); + + let visibileGitIgnoreContent = ``; + try { + visibileGitIgnoreContent = await fs.readFile( + visibileGitIgnorePath, + 'utf-8', + ); + } catch (error) { + if (isNodeError(error) && error.code !== 'ENOENT') { + throw error; + } + } + + await fs.writeFile(hiddenGitIgnorePath, visibileGitIgnoreContent); + + if (!visibileGitIgnoreContent.includes(historyDirName)) { + const updatedContent = `${visibileGitIgnoreContent}\n# Gemini CLI history directory\n${historyDirName}\n`; + await fs.writeFile(visibileGitIgnorePath, updatedContent); + } + + const commit = await repoInstance.raw([ + 'rev-list', + '--all', + '--max-count=1', + ]); + if (!commit) { + await repoInstance.add(hiddenGitIgnorePath); + + await repoInstance.commit('Initial commit'); + } + } + + private get hiddenGitRepository(): SimpleGit { + const historyDir = path.join(this.projectRoot, historyDirName); + const repoDir = path.join(historyDir, 'repository'); + return simpleGit(this.projectRoot).env({ + GIT_DIR: path.join(repoDir, '.git'), + GIT_WORK_TREE: this.projectRoot, + }); + } + + async getCurrentCommitHash(): Promise<string> { + const hash = await this.hiddenGitRepository.raw('rev-parse', 'HEAD'); + return hash.trim(); + } + + async createFileSnapshot(message: string): Promise<string> { + const repo = this.hiddenGitRepository; + await repo.add('.'); + const commitResult = await repo.commit(message); + return commitResult.commit; + } + + async restoreProjectFromSnapshot(commitHash: string): Promise<void> { + const repo = this.hiddenGitRepository; + await repo.raw(['restore', '--source', commitHash, '.']); + } +} |
