summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts10
-rw-r--r--packages/cli/src/ui/hooks/useToolScheduler.ts70
-rw-r--r--packages/server/src/core/client.ts50
-rw-r--r--packages/server/src/core/turn.ts14
4 files changed, 52 insertions, 92 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 324a4ffa..d3ecad95 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -383,6 +383,7 @@ export const useGeminiStream = (
toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
+ cancel();
return StreamProcessingStatus.UserCancelled;
} else if (event.type === ServerGeminiEventType.Error) {
handleErrorEvent(event.value, userMessageTimestamp);
@@ -393,12 +394,9 @@ export const useGeminiStream = (
return StreamProcessingStatus.Completed;
};
- const streamingState: StreamingState = isResponding
- ? StreamingState.Responding
- : pendingToolCalls?.tools.some(
- (t) => t.status === ToolCallStatus.Confirming,
- )
- ? StreamingState.WaitingForConfirmation
+ const streamingState: StreamingState =
+ isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
+ ? StreamingState.Responding
: StreamingState.Idle;
const submitQuery = useCallback(
diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts
index fde632df..e14241b6 100644
--- a/packages/cli/src/ui/hooks/useToolScheduler.ts
+++ b/packages/cli/src/ui/hooks/useToolScheduler.ts
@@ -184,61 +184,55 @@ export function useToolScheduler(
useEffect(() => {
// effect for executing scheduled tool calls
- const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
- const awaitingConfirmation = toolCalls.some(
- (t) => t.status === 'awaiting_approval',
- );
- if (!awaitingConfirmation && scheduledCalls.length) {
- scheduledCalls.forEach(async (c) => {
+ if (toolCalls.every((t) => t.status === 'scheduled')) {
+ toolCalls.forEach((c) => {
const callId = c.request.callId;
- try {
- setToolCalls(setStatus(c.request.callId, 'executing'));
- const result = await c.tool.execute(
- c.request.args,
- abortController.signal,
- );
- const functionResponse: Part = {
- functionResponse: {
- name: c.request.name,
- id: callId,
- response: { output: result.llmContent },
- },
- };
- const response: ToolCallResponseInfo = {
- callId,
- responsePart: functionResponse,
- resultDisplay: result.returnDisplay,
- error: undefined,
- };
- setToolCalls(setStatus(callId, 'success', response));
- } catch (e: unknown) {
- setToolCalls(
- setStatus(
+ setToolCalls(setStatus(c.request.callId, 'executing'));
+ c.tool
+ .execute(c.request.args, abortController.signal)
+ .then((result) => {
+ const functionResponse: Part = {
+ functionResponse: {
+ name: c.request.name,
+ id: callId,
+ response: { output: result.llmContent },
+ },
+ };
+ const response: ToolCallResponseInfo = {
callId,
- 'error',
- toolErrorResponse(
- c.request,
- e instanceof Error ? e : new Error(String(e)),
+ responsePart: functionResponse,
+ resultDisplay: result.returnDisplay,
+ error: undefined,
+ };
+ setToolCalls(setStatus(callId, 'success', response));
+ })
+ .catch((e) =>
+ setToolCalls(
+ setStatus(
+ callId,
+ 'error',
+ toolErrorResponse(
+ c.request,
+ e instanceof Error ? e : new Error(String(e)),
+ ),
),
),
);
- }
});
}
}, [toolCalls, toolRegistry, abortController.signal]);
useEffect(() => {
- const completedTools = toolCalls.filter(
+ const allDone = toolCalls.every(
(t) =>
t.status === 'success' ||
t.status === 'error' ||
t.status === 'cancelled',
);
- const allDone = completedTools.length === toolCalls.length;
if (toolCalls.length && allDone) {
- onComplete(completedTools);
setToolCalls([]);
- setAbortController(new AbortController());
+ onComplete(toolCalls);
+ setAbortController(() => new AbortController());
}
}, [toolCalls, onComplete]);
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts
index 489e2a0b..85850da8 100644
--- a/packages/server/src/core/client.ts
+++ b/packages/server/src/core/client.ts
@@ -16,7 +16,7 @@ import {
} from '@google/genai';
import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js';
-import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js';
+import { Turn, ServerGeminiStreamEvent } from './turn.js';
import { Config } from '../config/config.js';
import { getCoreSystemPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
@@ -153,43 +153,23 @@ export class GeminiClient {
chat: Chat,
request: PartListUnion,
signal?: AbortSignal,
+ turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> {
- let turns = 0;
- while (turns < this.MAX_TURNS) {
- turns++;
- const turn = new Turn(chat);
- const resultStream = turn.run(request, signal);
- let seenError = false;
- for await (const event of resultStream) {
- seenError =
- seenError === false ? false : event.type === GeminiEventType.Error;
- yield event;
- }
-
- const confirmations = turn.getConfirmationDetails();
- if (confirmations.length > 0) {
- break;
- }
-
- const fnResponses = turn.getFunctionResponses();
- if (fnResponses.length === 0) {
- const nextSpeakerCheck = await checkNextSpeaker(chat, this);
- if (nextSpeakerCheck?.next_speaker === 'model') {
- request = [{ text: 'Please continue.' }];
- continue;
- } else {
- break;
- }
- }
- request = fnResponses;
+ if (!turns) {
+ return;
+ }
- if (seenError) {
- // We saw an error, lets stop processing to prevent unexpected consequences.
- break;
- }
+ const turn = new Turn(chat);
+ const resultStream = turn.run(request, signal);
+ for await (const event of resultStream) {
+ yield event;
}
- if (turns >= this.MAX_TURNS) {
- console.warn('sendMessageStream: Reached maximum tool call turns limit.');
+ if (!turn.pendingToolCalls.length) {
+ const nextSpeakerCheck = await checkNextSpeaker(chat, this);
+ if (nextSpeakerCheck?.next_speaker === 'model') {
+ const nextRequest = [{ text: 'Please continue.' }];
+ return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
+ }
}
}
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index 7b2a96f9..38932041 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
- private pendingToolCalls: Array<{
+ readonly pendingToolCalls: Array<{
callId: string;
name: string;
args: Record<string, unknown>;
}>;
- private fnResponses: Part[];
- private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[];
constructor(private readonly chat: Chat) {
this.pendingToolCalls = [];
- this.fnResponses = [];
- this.confirmationDetails = [];
this.debugResponses = [];
}
// The run method yields simpler events suitable for server logic
@@ -182,14 +178,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value };
}
- getConfirmationDetails(): ToolCallConfirmationDetails[] {
- return this.confirmationDetails;
- }
-
- getFunctionResponses(): Part[] {
- return this.fnResponses;
- }
-
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}