summaryrefslogtreecommitdiff
path: root/packages/cli/src/core/turn.ts
diff options
context:
space:
mode:
authorJaana Dogan <[email protected]>2025-04-18 23:11:33 -0700
committerGitHub <[email protected]>2025-04-18 23:11:33 -0700
commit24371a39546a7802ce612c76f5a250b35a739acc (patch)
tree5e80e2b856247b76e95b84b9ec0f5ed859a4f53d /packages/cli/src/core/turn.ts
parent65e8e3ed1f759a273f2a8f667e8af4bbeeccaa36 (diff)
Take the turn management out of GeminiClient (#42)
Diffstat (limited to 'packages/cli/src/core/turn.ts')
-rw-r--r--packages/cli/src/core/turn.ts233
1 files changed, 233 insertions, 0 deletions
diff --git a/packages/cli/src/core/turn.ts b/packages/cli/src/core/turn.ts
new file mode 100644
index 00000000..e8c4ef78
--- /dev/null
+++ b/packages/cli/src/core/turn.ts
@@ -0,0 +1,233 @@
+import {
+ Part,
+ Chat,
+ PartListUnion,
+ GenerateContentResponse,
+ FunctionCall,
+} from '@google/genai';
+import {
+ type ToolCallConfirmationDetails,
+ ToolCallStatus,
+ ToolCallEvent,
+} from '../ui/types.js';
+import { ToolResult } from '../tools/tools.js';
+import { toolRegistry } from '../tools/tool-registry.js';
+import { GeminiEventType, GeminiStream } from './gemini-stream.js';
+
+export type ToolExecutionOutcome = {
+ callId: string;
+ name: string;
+ args: Record<string, never>;
+ result?: ToolResult;
+ error?: Error;
+ confirmationDetails?: ToolCallConfirmationDetails;
+};
+
+// TODO(jbd): Move ToolExecutionOutcome to somewhere else?
+
+// A turn manages the agentic loop turn.
+// Turn.run emits throught the turn events that could be used
+// as immediate feedback to the user.
+export class Turn {
+ private readonly chat: Chat;
+ private pendingToolCalls: Array<{
+ callId: string;
+ name: string;
+ args: Record<string, never>;
+ }>;
+ private fnResponses: Part[];
+ private debugResponses: GenerateContentResponse[];
+
+ constructor(chat: Chat) {
+ this.chat = chat;
+ this.pendingToolCalls = [];
+ this.fnResponses = [];
+ this.debugResponses = [];
+ }
+
+ async *run(req: PartListUnion, signal?: AbortSignal): GeminiStream {
+ const responseStream = await this.chat.sendMessageStream({
+ message: req,
+ });
+ for await (const resp of responseStream) {
+ this.debugResponses.push(resp);
+ if (signal?.aborted) {
+ throw this.abortError();
+ }
+ if (resp.text) {
+ yield {
+ type: GeminiEventType.Content,
+ value: resp.text,
+ };
+ continue;
+ }
+ if (!resp.functionCalls) {
+ continue;
+ }
+ for (const fnCall of resp.functionCalls) {
+ for await (const event of this.handlePendingFunctionCall(fnCall)) {
+ yield event;
+ }
+ }
+
+ // Create promises to be able to wait for executions to complete.
+ const toolPromises = this.pendingToolCalls.map(
+ async (pendingToolCall) => {
+ const tool = toolRegistry.getTool(pendingToolCall.name);
+ if (!tool) {
+ return {
+ ...pendingToolCall,
+ error: new Error(
+ `Tool "${pendingToolCall.name}" not found or is not registered.`,
+ ),
+ };
+ }
+ const shouldConfirm = await tool.shouldConfirmExecute(
+ pendingToolCall.args,
+ );
+ if (shouldConfirm) {
+ return {
+ // TODO(jbd): Should confirm isn't confirmation details.
+ ...pendingToolCall,
+ confirmationDetails: shouldConfirm,
+ };
+ }
+ const result = await tool.execute(pendingToolCall.args);
+ return { ...pendingToolCall, result };
+ },
+ );
+ const outcomes = await Promise.all(toolPromises);
+ for await (const event of this.handleToolOutcomes(outcomes)) {
+ yield event;
+ }
+ this.pendingToolCalls = [];
+
+ // TODO(jbd): Make it harder for the caller to ignore the
+ // buffered function responses.
+ this.fnResponses = this.buildFunctionResponses(outcomes);
+ }
+ }
+
+ private async *handlePendingFunctionCall(fnCall: FunctionCall): GeminiStream {
+ const callId =
+ fnCall.id ??
+ `${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
+ // TODO(jbd): replace with uuid.
+ const name = fnCall.name || 'undefined_tool_name';
+ const args = (fnCall.args || {}) as Record<string, never>;
+
+ this.pendingToolCalls.push({ callId, name, args });
+ const value: ToolCallEvent = {
+ type: 'tool_call',
+ status: ToolCallStatus.Pending,
+ callId,
+ name,
+ args,
+ resultDisplay: undefined,
+ confirmationDetails: undefined,
+ };
+ yield {
+ type: GeminiEventType.ToolCallInfo,
+ value,
+ };
+ }
+
+ private async *handleToolOutcomes(
+ outcomes: ToolExecutionOutcome[],
+ ): GeminiStream {
+ for (const outcome of outcomes) {
+ const { callId, name, args, result, error, confirmationDetails } =
+ outcome;
+ if (error) {
+ // TODO(jbd): Error handling needs a cleanup.
+ const errorMessage = error?.message || String(error);
+ yield {
+ type: GeminiEventType.Content,
+ value: `[Error invoking tool ${name}: ${errorMessage}]`,
+ };
+ return;
+ }
+ if (
+ result &&
+ typeof result === 'object' &&
+ result !== null &&
+ 'error' in result
+ ) {
+ const errorMessage = String(result.error);
+ yield {
+ type: GeminiEventType.Content,
+ value: `[Error executing tool ${name}: ${errorMessage}]`,
+ };
+ return;
+ }
+ const status = confirmationDetails
+ ? ToolCallStatus.Confirming
+ : ToolCallStatus.Invoked;
+ const value: ToolCallEvent = {
+ type: 'tool_call',
+ status,
+ callId,
+ name,
+ args,
+ resultDisplay: result?.returnDisplay,
+ confirmationDetails,
+ };
+ yield {
+ type: GeminiEventType.ToolCallInfo,
+ value,
+ };
+ }
+ }
+
+ private buildFunctionResponses(outcomes: ToolExecutionOutcome[]): Part[] {
+ return outcomes.map((outcome: ToolExecutionOutcome): Part => {
+ const { name, result, error } = outcome;
+ const output = { output: result?.llmContent };
+ let fnResponse: Record<string, unknown>;
+
+ if (error) {
+ const errorMessage = error?.message || String(error);
+ fnResponse = {
+ error: `Invocation failed: ${errorMessage}`,
+ };
+ console.error(`[Turn] Critical error invoking tool ${name}:`, error);
+ } else if (
+ result &&
+ typeof result === 'object' &&
+ result !== null &&
+ 'error' in result
+ ) {
+ fnResponse = output;
+ console.warn(
+ `[Turn] Tool ${name} returned an error structure:`,
+ result.error,
+ );
+ } else {
+ fnResponse = output;
+ }
+
+ return {
+ functionResponse: {
+ name,
+ id: outcome.callId,
+ response: fnResponse,
+ },
+ };
+ });
+ }
+
+ private abortError(): Error {
+ // TODO(jbd): Move it out of this class.
+ const error = new Error('Request cancelled by user during stream.');
+ error.name = 'AbortError';
+ throw error;
+ }
+
+ getFunctionResponses(): Part[] {
+ return this.fnResponses;
+ }
+
+ getDebugResponses(): GenerateContentResponse[] {
+ return this.debugResponses;
+ }
+}