summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useGeminiStream.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts119
1 files changed, 116 insertions, 3 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 725d8737..7d0fe375 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -7,6 +7,7 @@
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
+ Config,
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
@@ -14,14 +15,15 @@ import {
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
- Config,
MessageSenderType,
ToolCallRequestInfo,
logUserPrompt,
+ GitService,
} from '@gemini-cli/core';
import { type Part, type PartListUnion } from '@google/genai';
import {
StreamingState,
+ HistoryItem,
HistoryItemWithoutId,
HistoryItemToolGroup,
MessageType,
@@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
+import { promises as fs } from 'fs';
+import path from 'path';
import {
useReactToolScheduler,
mapToDisplay as mapTrackedToolCallsToDisplay,
@@ -68,13 +72,16 @@ enum StreamProcessingStatus {
*/
export const useGeminiStream = (
geminiClient: GeminiClient | null,
+ history: HistoryItem[],
addItem: UseHistoryManagerReturn['addItem'],
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
config: Config,
onDebugMessage: (message: string) => void,
handleSlashCommand: (
cmd: PartListUnion,
- ) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean,
+ ) => Promise<
+ import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean
+ >,
shellModeActive: boolean,
) => {
const [initError, setInitError] = useState<string | null>(null);
@@ -84,6 +91,12 @@ export const useGeminiStream = (
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats();
+ const gitService = useMemo(() => {
+ if (!config.getProjectRoot()) {
+ return;
+ }
+ return new GitService(config.getProjectRoot());
+ }, [config]);
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
@@ -178,7 +191,7 @@ export const useGeminiStream = (
await logger?.logMessage(MessageSenderType.USER, trimmedQuery);
// Handle UI-only commands first
- const slashCommandResult = handleSlashCommand(trimmedQuery);
+ const slashCommandResult = await handleSlashCommand(trimmedQuery);
if (typeof slashCommandResult === 'boolean' && slashCommandResult) {
// Command was handled, and it doesn't require a tool call from here
return { queryToSend: null, shouldProceed: false };
@@ -605,6 +618,106 @@ export const useGeminiStream = (
pendingToolCallGroupDisplay,
].filter((i) => i !== undefined && i !== null);
+ useEffect(() => {
+ const saveRestorableToolCalls = async () => {
+ if (!config.getCheckpointEnabled()) {
+ return;
+ }
+ const restorableToolCalls = toolCalls.filter(
+ (toolCall) =>
+ (toolCall.request.name === 'replace' ||
+ toolCall.request.name === 'write_file') &&
+ toolCall.status === 'awaiting_approval',
+ );
+
+ if (restorableToolCalls.length > 0) {
+ const checkpointDir = config.getGeminiDir()
+ ? path.join(config.getGeminiDir(), 'checkpoints')
+ : undefined;
+
+ if (!checkpointDir) {
+ return;
+ }
+
+ try {
+ await fs.mkdir(checkpointDir, { recursive: true });
+ } catch (error) {
+ if (!isNodeError(error) || error.code !== 'EEXIST') {
+ onDebugMessage(
+ `Failed to create checkpoint directory: ${getErrorMessage(error)}`,
+ );
+ return;
+ }
+ }
+
+ for (const toolCall of restorableToolCalls) {
+ const filePath = toolCall.request.args['file_path'] as string;
+ if (!filePath) {
+ onDebugMessage(
+ `Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`,
+ );
+ continue;
+ }
+
+ try {
+ let commitHash = await gitService?.createFileSnapshot(
+ `Snapshot for ${toolCall.request.name}`,
+ );
+
+ if (!commitHash) {
+ commitHash = await gitService?.getCurrentCommitHash();
+ }
+
+ if (!commitHash) {
+ onDebugMessage(
+ `Failed to create snapshot for ${filePath}. Skipping restorable tool call.`,
+ );
+ continue;
+ }
+
+ const timestamp = new Date()
+ .toISOString()
+ .replace(/:/g, '-')
+ .replace(/\./g, '_');
+ const toolName = toolCall.request.name;
+ const fileName = path.basename(filePath);
+ const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`;
+ const clientHistory = await geminiClient?.getHistory();
+ const toolCallWithSnapshotFilePath = path.join(
+ checkpointDir,
+ toolCallWithSnapshotFileName,
+ );
+
+ await fs.writeFile(
+ toolCallWithSnapshotFilePath,
+ JSON.stringify(
+ {
+ history,
+ clientHistory,
+ toolCall: {
+ name: toolCall.request.name,
+ args: toolCall.request.args,
+ },
+ commitHash,
+ filePath,
+ },
+ null,
+ 2,
+ ),
+ );
+ } catch (error) {
+ onDebugMessage(
+ `Failed to write restorable tool call file: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ }
+ }
+ };
+ saveRestorableToolCalls();
+ }, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]);
+
return {
streamingState,
submitQuery,