summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useReactToolScheduler.ts
diff options
context:
space:
mode:
authorN. Taylor Mullen <[email protected]>2025-06-01 14:16:24 -0700
committerGitHub <[email protected]>2025-06-01 14:16:24 -0700
commitf2a8d39f42ae88c1b7a9a5a75854363a53444ca2 (patch)
tree181d8eb3f1b1602f985fba4d2522b06c6c4f2eb6 /packages/cli/src/ui/hooks/useReactToolScheduler.ts
parentedc12e416d0b9daf24ede50cb18b012cb2b6e18a (diff)
refactor: Centralize tool scheduling logic and simplify React hook (#670)
Diffstat (limited to 'packages/cli/src/ui/hooks/useReactToolScheduler.ts')
-rw-r--r--packages/cli/src/ui/hooks/useReactToolScheduler.ts301
1 files changed, 301 insertions, 0 deletions
diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
new file mode 100644
index 00000000..12333d92
--- /dev/null
+++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
@@ -0,0 +1,301 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+ Config,
+ ToolCallRequestInfo,
+ ExecutingToolCall,
+ ScheduledToolCall,
+ ValidatingToolCall,
+ WaitingToolCall,
+ CompletedToolCall,
+ CancelledToolCall,
+ CoreToolScheduler,
+ OutputUpdateHandler,
+ AllToolCallsCompleteHandler,
+ ToolCallsUpdateHandler,
+ Tool,
+ ToolCall,
+ Status as CoreStatus,
+} from '@gemini-code/core';
+import { useCallback, useEffect, useState, useRef } from 'react';
+import {
+ HistoryItemToolGroup,
+ IndividualToolCallDisplay,
+ ToolCallStatus,
+ HistoryItemWithoutId,
+} from '../types.js';
+
+export type ScheduleFn = (
+ request: ToolCallRequestInfo | ToolCallRequestInfo[],
+) => void;
+export type CancelFn = (reason?: string) => void;
+export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
+
+export type TrackedScheduledToolCall = ScheduledToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+export type TrackedValidatingToolCall = ValidatingToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+export type TrackedWaitingToolCall = WaitingToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+export type TrackedExecutingToolCall = ExecutingToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+export type TrackedCompletedToolCall = CompletedToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+export type TrackedCancelledToolCall = CancelledToolCall & {
+ responseSubmittedToGemini?: boolean;
+};
+
+export type TrackedToolCall =
+ | TrackedScheduledToolCall
+ | TrackedValidatingToolCall
+ | TrackedWaitingToolCall
+ | TrackedExecutingToolCall
+ | TrackedCompletedToolCall
+ | TrackedCancelledToolCall;
+
+export function useReactToolScheduler(
+ onComplete: (tools: CompletedToolCall[]) => void,
+ config: Config,
+ setPendingHistoryItem: React.Dispatch<
+ React.SetStateAction<HistoryItemWithoutId | null>
+ >,
+): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
+ const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
+ TrackedToolCall[]
+ >([]);
+ const schedulerRef = useRef<CoreToolScheduler | null>(null);
+
+ useEffect(() => {
+ const outputUpdateHandler: OutputUpdateHandler = (
+ toolCallId,
+ outputChunk,
+ ) => {
+ setPendingHistoryItem((prevItem) => {
+ if (prevItem?.type === 'tool_group') {
+ return {
+ ...prevItem,
+ tools: prevItem.tools.map((toolDisplay) =>
+ toolDisplay.callId === toolCallId &&
+ toolDisplay.status === ToolCallStatus.Executing
+ ? { ...toolDisplay, resultDisplay: outputChunk }
+ : toolDisplay,
+ ),
+ };
+ }
+ return prevItem;
+ });
+
+ setToolCallsForDisplay((prevCalls) =>
+ prevCalls.map((tc) => {
+ if (tc.request.callId === toolCallId && tc.status === 'executing') {
+ const executingTc = tc as TrackedExecutingToolCall;
+ return { ...executingTc, liveOutput: outputChunk };
+ }
+ return tc;
+ }),
+ );
+ };
+
+ const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = (
+ completedToolCalls,
+ ) => {
+ onComplete(completedToolCalls);
+ };
+
+ const toolCallsUpdateHandler: ToolCallsUpdateHandler = (
+ updatedCoreToolCalls: ToolCall[],
+ ) => {
+ setToolCallsForDisplay((prevTrackedCalls) =>
+ updatedCoreToolCalls.map((coreTc) => {
+ const existingTrackedCall = prevTrackedCalls.find(
+ (ptc) => ptc.request.callId === coreTc.request.callId,
+ );
+ const newTrackedCall: TrackedToolCall = {
+ ...coreTc,
+ responseSubmittedToGemini:
+ existingTrackedCall?.responseSubmittedToGemini ?? false,
+ } as TrackedToolCall;
+ return newTrackedCall;
+ }),
+ );
+ };
+
+ schedulerRef.current = new CoreToolScheduler({
+ toolRegistry: config.getToolRegistry(),
+ outputUpdateHandler,
+ onAllToolCallsComplete: allToolCallsCompleteHandler,
+ onToolCallsUpdate: toolCallsUpdateHandler,
+ });
+ }, [config, onComplete, setPendingHistoryItem]);
+
+ const schedule: ScheduleFn = useCallback(
+ async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
+ schedulerRef.current?.schedule(request);
+ },
+ [],
+ );
+
+ const cancel: CancelFn = useCallback((reason: string = 'unspecified') => {
+ schedulerRef.current?.cancelAll(reason);
+ }, []);
+
+ const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback(
+ (callIdsToMark: string[]) => {
+ setToolCallsForDisplay((prevCalls) =>
+ prevCalls.map((tc) =>
+ callIdsToMark.includes(tc.request.callId)
+ ? { ...tc, responseSubmittedToGemini: true }
+ : tc,
+ ),
+ );
+ },
+ [],
+ );
+
+ return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
+}
+
+/**
+ * Maps a CoreToolScheduler status to the UI's ToolCallStatus enum.
+ */
+function mapCoreStatusToDisplayStatus(coreStatus: CoreStatus): ToolCallStatus {
+ switch (coreStatus) {
+ case 'validating':
+ return ToolCallStatus.Executing;
+ case 'awaiting_approval':
+ return ToolCallStatus.Confirming;
+ case 'executing':
+ return ToolCallStatus.Executing;
+ case 'success':
+ return ToolCallStatus.Success;
+ case 'cancelled':
+ return ToolCallStatus.Canceled;
+ case 'error':
+ return ToolCallStatus.Error;
+ case 'scheduled':
+ return ToolCallStatus.Pending;
+ default: {
+ const exhaustiveCheck: never = coreStatus;
+ console.warn(`Unknown core status encountered: ${exhaustiveCheck}`);
+ return ToolCallStatus.Error;
+ }
+ }
+}
+
+/**
+ * Transforms `TrackedToolCall` objects into `HistoryItemToolGroup` objects for UI display.
+ */
+export function mapToDisplay(
+ toolOrTools: TrackedToolCall[] | TrackedToolCall,
+): HistoryItemToolGroup {
+ const toolCalls = Array.isArray(toolOrTools) ? toolOrTools : [toolOrTools];
+
+ const toolDisplays = toolCalls.map(
+ (trackedCall): IndividualToolCallDisplay => {
+ let displayName = trackedCall.request.name;
+ let description = '';
+ let renderOutputAsMarkdown = false;
+
+ const currentToolInstance =
+ 'tool' in trackedCall && trackedCall.tool
+ ? (trackedCall as { tool: Tool }).tool
+ : undefined;
+
+ if (currentToolInstance) {
+ displayName = currentToolInstance.displayName;
+ description = currentToolInstance.getDescription(
+ trackedCall.request.args,
+ );
+ renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown;
+ }
+
+ if (trackedCall.status === 'error') {
+ description = '';
+ }
+
+ const baseDisplayProperties: Omit<
+ IndividualToolCallDisplay,
+ 'status' | 'resultDisplay' | 'confirmationDetails'
+ > = {
+ callId: trackedCall.request.callId,
+ name: displayName,
+ description,
+ renderOutputAsMarkdown,
+ };
+
+ switch (trackedCall.status) {
+ case 'success':
+ return {
+ ...baseDisplayProperties,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay: trackedCall.response.resultDisplay,
+ confirmationDetails: undefined,
+ };
+ case 'error':
+ return {
+ ...baseDisplayProperties,
+ name: currentToolInstance?.displayName ?? trackedCall.request.name,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay: trackedCall.response.resultDisplay,
+ confirmationDetails: undefined,
+ };
+ case 'cancelled':
+ return {
+ ...baseDisplayProperties,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay: trackedCall.response.resultDisplay,
+ confirmationDetails: undefined,
+ };
+ case 'awaiting_approval':
+ return {
+ ...baseDisplayProperties,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay: undefined,
+ confirmationDetails: trackedCall.confirmationDetails,
+ };
+ case 'executing':
+ return {
+ ...baseDisplayProperties,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay:
+ (trackedCall as TrackedExecutingToolCall).liveOutput ?? undefined,
+ confirmationDetails: undefined,
+ };
+ case 'validating': // Fallthrough
+ case 'scheduled':
+ return {
+ ...baseDisplayProperties,
+ status: mapCoreStatusToDisplayStatus(trackedCall.status),
+ resultDisplay: undefined,
+ confirmationDetails: undefined,
+ };
+ default: {
+ const exhaustiveCheck: never = trackedCall;
+ return {
+ callId: (exhaustiveCheck as TrackedToolCall).request.callId,
+ name: 'Unknown Tool',
+ description: 'Encountered an unknown tool call state.',
+ status: ToolCallStatus.Error,
+ resultDisplay: 'Unknown tool call state',
+ confirmationDetails: undefined,
+ renderOutputAsMarkdown: false,
+ };
+ }
+ }
+ },
+ );
+
+ return {
+ type: 'tool_group',
+ tools: toolDisplays,
+ };
+}