diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 119 |
1 files changed, 116 insertions, 3 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 725d8737..7d0fe375 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -7,6 +7,7 @@ import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useInput } from 'ink'; import { + Config, GeminiClient, GeminiEventType as ServerGeminiEventType, ServerGeminiStreamEvent as GeminiEvent, @@ -14,14 +15,15 @@ import { ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, - Config, MessageSenderType, ToolCallRequestInfo, logUserPrompt, + GitService, } from '@gemini-cli/core'; import { type Part, type PartListUnion } from '@google/genai'; import { StreamingState, + HistoryItem, HistoryItemWithoutId, HistoryItemToolGroup, MessageType, @@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { promises as fs } from 'fs'; +import path from 'path'; import { useReactToolScheduler, mapToDisplay as mapTrackedToolCallsToDisplay, @@ -68,13 +72,16 @@ enum StreamProcessingStatus { */ export const useGeminiStream = ( geminiClient: GeminiClient | null, + history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, config: Config, onDebugMessage: (message: string) => void, handleSlashCommand: ( cmd: PartListUnion, - ) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean, + ) => Promise< + import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean + >, shellModeActive: boolean, ) => { const [initError, setInitError] = useState<string | null>(null); @@ -84,6 +91,12 @@ export const useGeminiStream = ( useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); const { startNewTurn, addUsage } = useSessionStats(); + const gitService = useMemo(() => { + if (!config.getProjectRoot()) { + return; + } + return new GitService(config.getProjectRoot()); + }, [config]); const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = useReactToolScheduler( @@ -178,7 +191,7 @@ export const useGeminiStream = ( await logger?.logMessage(MessageSenderType.USER, trimmedQuery); // Handle UI-only commands first - const slashCommandResult = handleSlashCommand(trimmedQuery); + const slashCommandResult = await handleSlashCommand(trimmedQuery); if (typeof slashCommandResult === 'boolean' && slashCommandResult) { // Command was handled, and it doesn't require a tool call from here return { queryToSend: null, shouldProceed: false }; @@ -605,6 +618,106 @@ export const useGeminiStream = ( pendingToolCallGroupDisplay, ].filter((i) => i !== undefined && i !== null); + useEffect(() => { + const saveRestorableToolCalls = async () => { + if (!config.getCheckpointEnabled()) { + return; + } + const restorableToolCalls = toolCalls.filter( + (toolCall) => + (toolCall.request.name === 'replace' || + toolCall.request.name === 'write_file') && + toolCall.status === 'awaiting_approval', + ); + + if (restorableToolCalls.length > 0) { + const checkpointDir = config.getGeminiDir() + ? path.join(config.getGeminiDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + return; + } + + try { + await fs.mkdir(checkpointDir, { recursive: true }); + } catch (error) { + if (!isNodeError(error) || error.code !== 'EEXIST') { + onDebugMessage( + `Failed to create checkpoint directory: ${getErrorMessage(error)}`, + ); + return; + } + } + + for (const toolCall of restorableToolCalls) { + const filePath = toolCall.request.args['file_path'] as string; + if (!filePath) { + onDebugMessage( + `Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`, + ); + continue; + } + + try { + let commitHash = await gitService?.createFileSnapshot( + `Snapshot for ${toolCall.request.name}`, + ); + + if (!commitHash) { + commitHash = await gitService?.getCurrentCommitHash(); + } + + if (!commitHash) { + onDebugMessage( + `Failed to create snapshot for ${filePath}. Skipping restorable tool call.`, + ); + continue; + } + + const timestamp = new Date() + .toISOString() + .replace(/:/g, '-') + .replace(/\./g, '_'); + const toolName = toolCall.request.name; + const fileName = path.basename(filePath); + const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`; + const clientHistory = await geminiClient?.getHistory(); + const toolCallWithSnapshotFilePath = path.join( + checkpointDir, + toolCallWithSnapshotFileName, + ); + + await fs.writeFile( + toolCallWithSnapshotFilePath, + JSON.stringify( + { + history, + clientHistory, + toolCall: { + name: toolCall.request.name, + args: toolCall.request.args, + }, + commitHash, + filePath, + }, + null, + 2, + ), + ); + } catch (error) { + onDebugMessage( + `Failed to write restorable tool call file: ${getErrorMessage( + error, + )}`, + ); + } + } + } + }; + saveRestorableToolCalls(); + }, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]); + return { streamingState, submitQuery, |
