Diffstat (limited to 'packages/cli/src')
-rw-r--r--  packages/cli/src/core/gemini-client.ts  412
-rw-r--r--  packages/cli/src/core/turn.ts           233
2 files changed, 335 insertions, 310 deletions
diff --git a/packages/cli/src/core/gemini-client.ts b/packages/cli/src/core/gemini-client.ts
index be338754..19dba40f 100644
--- a/packages/cli/src/core/gemini-client.ts
+++ b/packages/cli/src/core/gemini-client.ts
@@ -15,31 +15,17 @@ import {
Content,
} from '@google/genai';
import { CoreSystemPrompt } from './prompts.js';
-import {
- type ToolCallEvent,
- type ToolCallConfirmationDetails,
- ToolCallStatus,
-} from '../ui/types.js';
import process from 'node:process';
import { toolRegistry } from '../tools/tool-registry.js';
-import { ToolResult } from '../tools/tools.js';
import { getFolderStructure } from '../utils/getFolderStructure.js';
import { GeminiEventType, GeminiStream } from './gemini-stream.js';
import { Config } from '../config/config.js';
-
-type ToolExecutionOutcome = {
- callId: string;
- name: string;
- args: Record<string, never>;
- result?: ToolResult;
- error?: Error;
- confirmationDetails?: ToolCallConfirmationDetails;
-};
+import { Turn } from './turn.js';
export class GeminiClient {
private config: Config;
private ai: GoogleGenAI;
- private defaultHyperParameters: GenerateContentConfig = {
+ private generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
};
@@ -50,14 +36,9 @@ export class GeminiClient {
this.ai = new GoogleGenAI({ apiKey: config.getApiKey() });
}
- async startChat(): Promise<Chat> {
- const tools = toolRegistry.getToolSchemas();
- const model = this.config.getModel();
-
- // --- Get environmental information ---
+ private async getEnvironment(): Promise<Part> {
const cwd = process.cwd();
const today = new Date().toLocaleDateString(undefined, {
- // Use locale-aware date formatting
weekday: 'long',
year: 'numeric',
month: 'long',
@@ -65,33 +46,37 @@ export class GeminiClient {
});
const platform = process.platform;
- // --- Format information into a conversational multi-line string ---
const folderStructure = await getFolderStructure(cwd);
- // --- End folder structure formatting ---)
- const initialContextText = `
-Okay, just setting up the context for our chat.
-Today is ${today}.
-My operating system is: ${platform}
-I'm currently working in the directory: ${cwd}
-${folderStructure}
- `.trim();
- const initialContextPart: Part = { text: initialContextText };
- // --- End environmental information formatting ---
+ const context = `
+ Okay, just setting up the context for our chat.
+ Today is ${today}.
+ My operating system is: ${platform}
+ I'm currently working in the directory: ${cwd}
+ ${folderStructure}
+ `.trim();
+
+ return { text: context };
+ }
+
+ async startChat(): Promise<Chat> {
+ const envPart = await this.getEnvironment();
+ const model = this.config.getModel();
+ const tools = toolRegistry.getToolSchemas();
try {
const chat = this.ai.chats.create({
model,
config: {
systemInstruction: CoreSystemPrompt,
- ...this.defaultHyperParameters,
+ ...this.generateContentConfig,
tools,
},
history: [
// --- Add the context as a single part in the initial user message ---
{
role: 'user',
- parts: [initialContextPart], // Pass the single Part object in an array
+ parts: [envPart], // Pass the single Part object in an array
},
// --- Add an empty model response to balance the history ---
{
@@ -109,308 +94,113 @@ ${folderStructure}
}
}
- addMessageToHistory(chat: Chat, message: Content): void {
- const history = chat.getHistory();
- history.push(message);
- }
-
async *sendMessageStream(
chat: Chat,
request: PartListUnion,
signal?: AbortSignal,
): GeminiStream {
- let currentMessageToSend: PartListUnion = request;
let turns = 0;
try {
while (turns < this.MAX_TURNS) {
turns++;
- const resultStream = await chat.sendMessageStream({
- message: currentMessageToSend,
- });
- let functionResponseParts: Part[] = [];
- let pendingToolCalls: Array<{
- callId: string;
- name: string;
- args: Record<string, never>;
- }> = [];
- let yieldedTextInTurn = false;
- const chunksForDebug = [];
-
- for await (const chunk of resultStream) {
- chunksForDebug.push(chunk);
- if (signal?.aborted) {
- const abortError = new Error(
- 'Request cancelled by user during stream.',
- );
- abortError.name = 'AbortError';
- throw abortError;
- }
-
- const functionCalls = chunk.functionCalls;
- if (functionCalls && functionCalls.length > 0) {
- for (const call of functionCalls) {
- const callId =
- call.id ??
- `${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
- const name = call.name || 'undefined_tool_name';
- const args = (call.args || {}) as Record<string, never>;
+ // A turn either yields a text response or returns
+ // function responses to be used in the next turn.
+ // This call site is responsible for handling the buffered
+ // function responses and using them on the next turn.
+ const turn = new Turn(chat);
+ const resultStream = turn.run(request, signal);
- pendingToolCalls.push({ callId, name, args });
- const evtValue: ToolCallEvent = {
- type: 'tool_call',
- status: ToolCallStatus.Pending,
- callId,
- name,
- args,
- resultDisplay: undefined,
- confirmationDetails: undefined,
- };
- yield {
- type: GeminiEventType.ToolCallInfo,
- value: evtValue,
- };
- }
- } else {
- const text = chunk.text;
- if (text) {
- yieldedTextInTurn = true;
- yield {
- type: GeminiEventType.Content,
- value: text,
- };
- }
- }
+ for await (const event of resultStream) {
+ yield event;
+ }
+ const fnResponses = turn.getFunctionResponses();
+ if (fnResponses.length > 0) {
+ request = fnResponses;
+ continue; // use the responses in the next turn
}
- if (pendingToolCalls.length > 0) {
- const toolPromises: Array<Promise<ToolExecutionOutcome>> =
- pendingToolCalls.map(async (pendingToolCall) => {
- const tool = toolRegistry.getTool(pendingToolCall.name);
-
- if (!tool) {
- // Directly return error outcome if tool not found
- return {
- ...pendingToolCall,
- error: new Error(
- `Tool "${pendingToolCall.name}" not found or is not registered.`,
- ),
- };
- }
-
- try {
- const confirmation = await tool.shouldConfirmExecute(
- pendingToolCall.args,
- );
- if (confirmation) {
- return {
- ...pendingToolCall,
- confirmationDetails: confirmation,
- };
- }
- } catch (error) {
- return {
- ...pendingToolCall,
- error: new Error(
- `Tool failed to check tool confirmation: ${error}`,
- ),
- };
- }
-
- try {
- const result = await tool.execute(pendingToolCall.args);
- return { ...pendingToolCall, result };
- } catch (error) {
- return {
- ...pendingToolCall,
- error: new Error(`Tool failed to execute: ${error}`),
- };
- }
- });
- const toolExecutionOutcomes: ToolExecutionOutcome[] =
- await Promise.all(toolPromises);
-
- for (const executedTool of toolExecutionOutcomes) {
- const { callId, name, args, result, error, confirmationDetails } =
- executedTool;
-
- if (error) {
- const errorMessage = error?.message || String(error);
- yield {
- type: GeminiEventType.Content,
- value: `[Error invoking tool ${name}: ${errorMessage}]`,
- };
- } else if (
- result &&
- typeof result === 'object' &&
- result !== null &&
- 'error' in result
- ) {
- const errorMessage = String(result.error);
- yield {
- type: GeminiEventType.Content,
- value: `[Error executing tool ${name}: ${errorMessage}]`,
- };
- } else {
- const status = confirmationDetails
- ? ToolCallStatus.Confirming
- : ToolCallStatus.Invoked;
- const evtValue: ToolCallEvent = {
- type: 'tool_call',
- status,
- callId,
- name,
- args,
- resultDisplay: result?.returnDisplay,
- confirmationDetails,
- };
- yield {
- type: GeminiEventType.ToolCallInfo,
- value: evtValue,
- };
- }
- }
-
- pendingToolCalls = [];
-
- const waitingOnConfirmations =
- toolExecutionOutcomes.filter(
- (outcome) => outcome.confirmationDetails,
- ).length > 0;
- if (waitingOnConfirmations) {
- // Stop processing content, wait for user.
- // TODO: Kill token processing once API supports signals.
- break;
- }
-
- functionResponseParts = toolExecutionOutcomes.map(
- (executedTool: ToolExecutionOutcome): Part => {
- const { name, result, error } = executedTool;
- const output = { output: result?.llmContent };
- let toolOutcomePayload: Record<string, unknown>;
-
- if (error) {
- const errorMessage = error?.message || String(error);
- toolOutcomePayload = {
- error: `Invocation failed: ${errorMessage}`,
- };
- console.error(
- `[Turn ${turns}] Critical error invoking tool ${name}:`,
- error,
- );
- } else if (
- result &&
- typeof result === 'object' &&
- result !== null &&
- 'error' in result
- ) {
- toolOutcomePayload = output;
- console.warn(
- `[Turn ${turns}] Tool ${name} returned an error structure:`,
- result.error,
- );
- } else {
- toolOutcomePayload = output;
- }
-
- return {
- functionResponse: {
- name,
- id: executedTool.callId,
- response: toolOutcomePayload,
- },
- };
- },
- );
- currentMessageToSend = functionResponseParts;
- } else if (yieldedTextInTurn) {
- const history = chat.getHistory();
- const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
+ const history = chat.getHistory();
+ const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
-**Decision Rules (apply in order):**
+ **Decision Rules (apply in order):**
-1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next.
-2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next.
-3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next.
+ 1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next.
+ 2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next.
+ 3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next.
-**Output Format:**
+ **Output Format:**
-Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
+ Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
-\`\`\`json
-{
- "type": "object",
- "properties": {
- "reasoning": {
+ \`\`\`json
+ {
+ "type": "object",
+ "properties": {
+ "reasoning": {
+ "type": "string",
+ "description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn."
+ },
+ "next_speaker": {
"type": "string",
- "description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn."
+ "enum": ["user", "model"],
+ "description": "Who should speak next based *only* on the preceding turn and the decision rules."
+ }
},
- "next_speaker": {
- "type": "string",
- "enum": ["user", "model"],
- "description": "Who should speak next based *only* on the preceding turn and the decision rules."
- }
- },
- "required": ["next_speaker", "reasoning"]
-\`\`\`
-}`;
+ "required": ["next_speaker", "reasoning"]
+ \`\`\`
+ }`;
- // Schema Idea
- const responseSchema: SchemaUnion = {
- type: Type.OBJECT,
- properties: {
- reasoning: {
- type: Type.STRING,
- description:
- "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.",
- },
- next_speaker: {
- type: Type.STRING,
- enum: ['user', 'model'], // Enforce the choices
- description:
- 'Who should speak next based *only* on the preceding turn and the decision rules',
- },
+ // Schema Idea
+ const responseSchema: SchemaUnion = {
+ type: Type.OBJECT,
+ properties: {
+ reasoning: {
+ type: Type.STRING,
+ description:
+ "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.",
+ },
+ next_speaker: {
+ type: Type.STRING,
+ enum: ['user', 'model'], // Enforce the choices
+ description:
+ 'Who should speak next based *only* on the preceding turn and the decision rules',
},
- required: ['reasoning', 'next_speaker'],
- };
+ },
+ required: ['reasoning', 'next_speaker'],
+ };
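+ // Illustrative only (an assumption, not output captured from this change):
+ // a model reply satisfying responseSchema would look like
+ //   {
+ //     "reasoning": "The last response ended with a direct question to the user.",
+ //     "next_speaker": "user"
+ //   }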
- try {
- // Use the new generateJson method, passing the history and the check prompt
- const parsedResponse = await this.generateJson(
- [
- ...history,
- {
- role: 'user',
- parts: [{ text: checkPrompt }],
- },
- ],
- responseSchema,
- );
+ try {
+ // Use the new generateJson method, passing the history and the check prompt
+ const parsedResponse = await this.generateJson(
+ [
+ ...history,
+ {
+ role: 'user',
+ parts: [{ text: checkPrompt }],
+ },
+ ],
+ responseSchema,
+ );
- // Safely extract the next speaker value
- const nextSpeaker: string | undefined =
- typeof parsedResponse?.next_speaker === 'string'
- ? parsedResponse.next_speaker
- : undefined;
+ // Safely extract the next speaker value
+ const nextSpeaker: string | undefined =
+ typeof parsedResponse?.next_speaker === 'string'
+ ? parsedResponse.next_speaker
+ : undefined;
- if (nextSpeaker === 'model') {
- currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt
- } else {
- // 'user' should speak next, or value is missing/invalid. End the turn.
- break;
- }
- } catch (error) {
- console.error(
- `[Turn ${turns}] Failed to get or parse next speaker check:`,
- error,
- );
- // If the check fails, assume user should speak next to avoid infinite loops
+ if (nextSpeaker === 'model') {
+ request = { text: 'alright' }; // Or potentially a more meaningful continuation prompt
+ } else {
+ // 'user' should speak next, or value is missing/invalid. End the turn.
break;
}
- } else {
- console.warn(
- `[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`,
+ } catch (error) {
+ console.error(
+ `[Turn ${turns}] Failed to get or parse next speaker check:`,
+ error,
);
+ // If the check fails, assume user should speak next to avoid infinite loops
break;
}
}
@@ -426,6 +216,8 @@ Respond *only* in JSON format according to the following schema. Do not include
};
}
} catch (error: unknown) {
+ // TODO(jbd): There is so much of packing/unpacking of error types.
+ // Figure out a way to remove the redundant work.
if (error instanceof Error && error.name === 'AbortError') {
console.log('Gemini stream request aborted by user.');
throw error;
@@ -457,7 +249,7 @@ Respond *only* in JSON format according to the following schema. Do not include
const result = await this.ai.models.generateContent({
model,
config: {
- ...this.defaultHyperParameters,
+ ...this.generateContentConfig,
systemInstruction: CoreSystemPrompt,
responseSchema: schema,
responseMimeType: 'application/json',
diff --git a/packages/cli/src/core/turn.ts b/packages/cli/src/core/turn.ts
new file mode 100644
index 00000000..e8c4ef78
--- /dev/null
+++ b/packages/cli/src/core/turn.ts
@@ -0,0 +1,233 @@
+import {
+ Part,
+ Chat,
+ PartListUnion,
+ GenerateContentResponse,
+ FunctionCall,
+} from '@google/genai';
+import {
+ type ToolCallConfirmationDetails,
+ ToolCallStatus,
+ ToolCallEvent,
+} from '../ui/types.js';
+import { ToolResult } from '../tools/tools.js';
+import { toolRegistry } from '../tools/tool-registry.js';
+import { GeminiEventType, GeminiStream } from './gemini-stream.js';
+
+export type ToolExecutionOutcome = {
+ callId: string;
+ name: string;
+ args: Record<string, never>;
+ result?: ToolResult;
+ error?: Error;
+ confirmationDetails?: ToolCallConfirmationDetails;
+};
+
+// TODO(jbd): Move ToolExecutionOutcome to somewhere else?
+
+// A Turn manages a single turn of the agentic loop.
+// Turn.run emits events throughout the turn that can be used
+// as immediate feedback to the user.
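+//
+// Illustrative sketch (it mirrors the caller in GeminiClient.sendMessageStream;
+// not new API) of how a Turn is expected to be driven:
+//
+//   const turn = new Turn(chat);
+//   for await (const event of turn.run(request, signal)) {
+//     // surface each Content / ToolCallInfo event to the user
+//   }
+//   const fnResponses = turn.getFunctionResponses();
+//   if (fnResponses.length > 0) {
+//     // send the function responses as the next turn's request
+//   }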
+export class Turn {
+ private readonly chat: Chat;
+ private pendingToolCalls: Array<{
+ callId: string;
+ name: string;
+ args: Record<string, never>;
+ }>;
+ private fnResponses: Part[];
+ private debugResponses: GenerateContentResponse[];
+
+ constructor(chat: Chat) {
+ this.chat = chat;
+ this.pendingToolCalls = [];
+ this.fnResponses = [];
+ this.debugResponses = [];
+ }
+
+ async *run(req: PartListUnion, signal?: AbortSignal): GeminiStream {
+ const responseStream = await this.chat.sendMessageStream({
+ message: req,
+ });
+ for await (const resp of responseStream) {
+ this.debugResponses.push(resp);
+ if (signal?.aborted) {
+ throw this.abortError();
+ }
+ if (resp.text) {
+ yield {
+ type: GeminiEventType.Content,
+ value: resp.text,
+ };
+ continue;
+ }
+ if (!resp.functionCalls) {
+ continue;
+ }
+ for (const fnCall of resp.functionCalls) {
+ for await (const event of this.handlePendingFunctionCall(fnCall)) {
+ yield event;
+ }
+ }
+
+ // Create promises so that all tool executions can be awaited together.
+ const toolPromises = this.pendingToolCalls.map(
+ async (pendingToolCall) => {
+ const tool = toolRegistry.getTool(pendingToolCall.name);
+ if (!tool) {
+ return {
+ ...pendingToolCall,
+ error: new Error(
+ `Tool "${pendingToolCall.name}" not found or is not registered.`,
+ ),
+ };
+ }
+ const shouldConfirm = await tool.shouldConfirmExecute(
+ pendingToolCall.args,
+ );
+ if (shouldConfirm) {
+ return {
+ // TODO(jbd): shouldConfirmExecute's result isn't confirmation details.
+ ...pendingToolCall,
+ confirmationDetails: shouldConfirm,
+ };
+ }
+ const result = await tool.execute(pendingToolCall.args);
+ return { ...pendingToolCall, result };
+ },
+ );
+ const outcomes = await Promise.all(toolPromises);
+ for await (const event of this.handleToolOutcomes(outcomes)) {
+ yield event;
+ }
+ this.pendingToolCalls = [];
+
+ // TODO(jbd): Make it harder for the caller to ignore the
+ // buffered function responses.
+ this.fnResponses = this.buildFunctionResponses(outcomes);
+ }
+ }
+
+ private async *handlePendingFunctionCall(fnCall: FunctionCall): GeminiStream {
+ const callId =
+ fnCall.id ??
+ `${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
+ // TODO(jbd): replace with uuid.
+ const name = fnCall.name || 'undefined_tool_name';
+ const args = (fnCall.args || {}) as Record<string, never>;
+
+ this.pendingToolCalls.push({ callId, name, args });
+ const value: ToolCallEvent = {
+ type: 'tool_call',
+ status: ToolCallStatus.Pending,
+ callId,
+ name,
+ args,
+ resultDisplay: undefined,
+ confirmationDetails: undefined,
+ };
+ yield {
+ type: GeminiEventType.ToolCallInfo,
+ value,
+ };
+ }
+
+ private async *handleToolOutcomes(
+ outcomes: ToolExecutionOutcome[],
+ ): GeminiStream {
+ for (const outcome of outcomes) {
+ const { callId, name, args, result, error, confirmationDetails } =
+ outcome;
+ if (error) {
+ // TODO(jbd): Error handling needs a cleanup.
+ const errorMessage = error?.message || String(error);
+ yield {
+ type: GeminiEventType.Content,
+ value: `[Error invoking tool ${name}: ${errorMessage}]`,
+ };
+ continue; // report the error, but keep processing the remaining outcomes
+ }
+ if (
+ result &&
+ typeof result === 'object' &&
+ result !== null &&
+ 'error' in result
+ ) {
+ const errorMessage = String(result.error);
+ yield {
+ type: GeminiEventType.Content,
+ value: `[Error executing tool ${name}: ${errorMessage}]`,
+ };
+ continue; // report the error, but keep processing the remaining outcomes
+ }
+ const status = confirmationDetails
+ ? ToolCallStatus.Confirming
+ : ToolCallStatus.Invoked;
+ const value: ToolCallEvent = {
+ type: 'tool_call',
+ status,
+ callId,
+ name,
+ args,
+ resultDisplay: result?.returnDisplay,
+ confirmationDetails,
+ };
+ yield {
+ type: GeminiEventType.ToolCallInfo,
+ value,
+ };
+ }
+ }
+
+ private buildFunctionResponses(outcomes: ToolExecutionOutcome[]): Part[] {
+ return outcomes.map((outcome: ToolExecutionOutcome): Part => {
+ const { name, result, error } = outcome;
+ const output = { output: result?.llmContent };
+ let fnResponse: Record<string, unknown>;
+
+ if (error) {
+ const errorMessage = error?.message || String(error);
+ fnResponse = {
+ error: `Invocation failed: ${errorMessage}`,
+ };
+ console.error(`[Turn] Critical error invoking tool ${name}:`, error);
+ } else if (
+ result &&
+ typeof result === 'object' &&
+ result !== null &&
+ 'error' in result
+ ) {
+ fnResponse = output;
+ console.warn(
+ `[Turn] Tool ${name} returned an error structure:`,
+ result.error,
+ );
+ } else {
+ fnResponse = output;
+ }
+
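+ // For illustration (assumed typical content; 'read_file' is a made-up
+ // tool name), the Part built below looks roughly like:
+ //   { functionResponse: { name: 'read_file', id: 'read_file-1700000000000-ab12',
+ //       response: { output: '<llmContent>' } } }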
+ return {
+ functionResponse: {
+ name,
+ id: outcome.callId,
+ response: fnResponse,
+ },
+ };
+ });
+ }
+
+ private abortError(): Error {
+ // TODO(jbd): Move it out of this class.
+ const error = new Error('Request cancelled by user during stream.');
+ error.name = 'AbortError';
+ return error;
+ }
+
+ getFunctionResponses(): Part[] {
+ return this.fnResponses;
+ }
+
+ getDebugResponses(): GenerateContentResponse[] {
+ return this.debugResponses;
+ }
+}