summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useToolScheduler.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useToolScheduler.ts')
-rw-r--r--packages/cli/src/ui/hooks/useToolScheduler.ts90
1 files changed, 69 insertions, 21 deletions
diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts
index a5770d36..1bb44133 100644
--- a/packages/cli/src/ui/hooks/useToolScheduler.ts
+++ b/packages/cli/src/ui/hooks/useToolScheduler.ts
@@ -20,6 +20,12 @@ import {
ToolCallStatus,
} from '../types.js';
+type ValidatingToolCall = {
+ status: 'validating';
+ request: ToolCallRequestInfo;
+ tool: Tool;
+};
+
type ScheduledToolCall = {
status: 'scheduled';
request: ToolCallRequestInfo;
@@ -62,6 +68,7 @@ type WaitingToolCall = {
export type Status = ToolCall['status'];
export type ToolCall =
+ | ValidatingToolCall
| ScheduledToolCall
| ErroredToolCall
| SuccessfulToolCall
@@ -99,9 +106,12 @@ export function useToolScheduler(
'Cannot schedule tool calls while other tool calls are running',
);
}
- const requests = Array.isArray(request) ? request : [request];
- const newCalls: ToolCall[] = await Promise.all(
- requests.map(async (r): Promise<ToolCall> => {
+ const requestsToProcess = Array.isArray(request) ? request : [request];
+
+ // Step 1: Create initial calls with 'validating' status (or 'error' if tool not found)
+ // and add them to the state immediately to make the UI busy.
+ const initialNewCalls: ToolCall[] = requestsToProcess.map(
+ (r): ToolCall => {
const tool = toolRegistry.getTool(r.name);
if (!tool) {
return {
@@ -113,16 +123,27 @@ export function useToolScheduler(
),
};
}
+ // Set to 'validating' immediately. This will make streamingState 'Responding'.
+ return { status: 'validating', request: r, tool };
+ },
+ );
+ setToolCalls((prevCalls) => prevCalls.concat(initialNewCalls));
+ // Step 2: Asynchronously check for confirmation and update status for each new call.
+ initialNewCalls.forEach(async (initialCall) => {
+ // If the call was already marked as an error (tool not found), skip further processing.
+ if (initialCall.status !== 'validating') return;
+
+ const { request: r, tool } = initialCall;
+ try {
const userApproval = await tool.shouldConfirmExecute(r.args);
if (userApproval) {
- return {
- status: 'awaiting_approval',
- request: r,
- tool,
- confirmationDetails: {
+ // Confirmation is needed. Update status to 'awaiting_approval'.
+ setToolCalls(
+ setStatus(r.callId, 'awaiting_approval', {
...userApproval,
onConfirm: async (outcome) => {
+ // This onConfirm is triggered by user interaction later.
await userApproval.onConfirm(outcome);
setToolCalls(
outcome === ToolConfirmationOutcome.Cancel
@@ -131,21 +152,30 @@ export function useToolScheduler(
'cancelled',
'User did not allow tool call',
)
- : setStatus(r.callId, 'scheduled'),
+ : // If confirmed, it goes to 'scheduled' to be picked up by the execution effect.
+ setStatus(r.callId, 'scheduled'),
);
},
- },
- };
+ }),
+ );
+ } else {
+ // No confirmation needed, move to 'scheduled' for execution.
+ setToolCalls(setStatus(r.callId, 'scheduled'));
}
-
- return {
- status: 'scheduled',
- request: r,
- tool,
- };
- }),
- );
- setToolCalls((t) => t.concat(newCalls));
+ } catch (e) {
+ // Handle errors from tool.shouldConfirmExecute() itself.
+ setToolCalls(
+ setStatus(
+ r.callId,
+ 'error',
+ toolErrorResponse(
+ r,
+ e instanceof Error ? e : new Error(String(e)),
+ ),
+ ),
+ );
+ }
+ });
},
[isRunning, setToolCalls, toolRegistry],
);
@@ -273,7 +303,7 @@ function setStatus(
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
- status: 'executing' | 'scheduled',
+ status: 'executing' | 'scheduled' | 'validating',
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
@@ -338,6 +368,13 @@ function setStatus(
};
return next;
}
+ case 'validating': {
+ const next: ValidatingToolCall = {
+ ...(t as ValidatingToolCall), // Added type assertion for safety
+ status: 'validating',
+ };
+ return next;
+ }
case 'executing': {
const next: ExecutingToolCall = {
...t,
@@ -373,6 +410,8 @@ const toolErrorResponse = (
function mapStatus(status: Status): ToolCallStatus {
switch (status) {
+ case 'validating':
+ return ToolCallStatus.Executing;
case 'awaiting_approval':
return ToolCallStatus.Confirming;
case 'executing':
@@ -445,6 +484,15 @@ export function mapToDisplay(
status: mapStatus(t.status),
confirmationDetails: undefined,
};
+ case 'validating': // Add this case
+ return {
+ callId: t.request.callId,
+ name: t.tool.displayName,
+ description: t.tool.getDescription(t.request.args),
+ resultDisplay: undefined,
+ status: mapStatus(t.status),
+ confirmationDetails: undefined,
+ };
case 'scheduled':
return {
callId: t.request.callId,