diff options
| author | Taylor Mullen <[email protected]> | 2025-04-17 18:06:21 -0400 |
|---|---|---|
| committer | N. Taylor Mullen <[email protected]> | 2025-04-17 15:29:34 -0700 |
| commit | cfc697a96d2e716a75e1c3b7f0f34fce81abaf1e (patch) | |
| tree | e06bcba67ca71a874048aa887b17457dbd409bdf /packages/cli/src/core/gemini-client.ts | |
| parent | 7928c1727f0b208ed34850cc89bbb36ea3dd23e5 (diff) | |
Run `npm run format`
- Also updated README.md accordingly.
Part of https://b.corp.google.com/issues/411384603
Diffstat (limited to 'packages/cli/src/core/gemini-client.ts')
| -rw-r--r-- | packages/cli/src/core/gemini-client.ts | 696 |
1 files changed, 398 insertions, 298 deletions
diff --git a/packages/cli/src/core/gemini-client.ts b/packages/cli/src/core/gemini-client.ts index 67812f8e..41cabdb7 100644 --- a/packages/cli/src/core/gemini-client.ts +++ b/packages/cli/src/core/gemini-client.ts @@ -1,13 +1,20 @@ import { - GenerateContentConfig, GoogleGenAI, Part, Chat, - Type, - SchemaUnion, - PartListUnion, - Content + GenerateContentConfig, + GoogleGenAI, + Part, + Chat, + Type, + SchemaUnion, + PartListUnion, + Content, } from '@google/genai'; import { getApiKey } from '../config/env.js'; import { CoreSystemPrompt } from './prompts.js'; -import { type ToolCallEvent, type ToolCallConfirmationDetails, ToolCallStatus } from '../ui/types.js'; +import { + type ToolCallEvent, + type ToolCallConfirmationDetails, + ToolCallStatus, +} from '../ui/types.js'; import process from 'node:process'; import { toolRegistry } from '../tools/tool-registry.js'; import { ToolResult } from '../tools/tools.js'; @@ -15,41 +22,45 @@ import { getFolderStructure } from '../utils/getFolderStructure.js'; import { GeminiEventType, GeminiStream } from './gemini-stream.js'; type ToolExecutionOutcome = { - callId: string; - name: string; - args: Record<string, any>; - result?: ToolResult; - error?: any; - confirmationDetails?: ToolCallConfirmationDetails; + callId: string; + name: string; + args: Record<string, any>; + result?: ToolResult; + error?: any; + confirmationDetails?: ToolCallConfirmationDetails; }; export class GeminiClient { - private ai: GoogleGenAI; - private defaultHyperParameters: GenerateContentConfig = { - temperature: 0, - topP: 1, - }; - private readonly MAX_TURNS = 100; + private ai: GoogleGenAI; + private defaultHyperParameters: GenerateContentConfig = { + temperature: 0, + topP: 1, + }; + private readonly MAX_TURNS = 100; - constructor() { - const apiKey = getApiKey(); - this.ai = new GoogleGenAI({ apiKey }); - } + constructor() { + const apiKey = getApiKey(); + this.ai = new GoogleGenAI({ apiKey }); + } - public async startChat(): Promise<Chat> { - const tools = toolRegistry.getToolSchemas(); + public async startChat(): Promise<Chat> { + const tools = toolRegistry.getToolSchemas(); - // --- Get environmental information --- - const cwd = process.cwd(); - const today = new Date().toLocaleDateString(undefined, { // Use locale-aware date formatting - weekday: 'long', year: 'numeric', month: 'long', day: 'numeric' - }); - const platform = process.platform; + // --- Get environmental information --- + const cwd = process.cwd(); + const today = new Date().toLocaleDateString(undefined, { + // Use locale-aware date formatting + weekday: 'long', + year: 'numeric', + month: 'long', + day: 'numeric', + }); + const platform = process.platform; - // --- Format information into a conversational multi-line string --- - const folderStructure = await getFolderStructure(cwd); - // --- End folder structure formatting ---) - const initialContextText = ` + // --- Format information into a conversational multi-line string --- + const folderStructure = await getFolderStructure(cwd); + // --- End folder structure formatting ---) + const initialContextText = ` Okay, just setting up the context for our chat. Today is ${today}. My operating system is: ${platform} @@ -57,194 +68,258 @@ I'm currently working in the directory: ${cwd} ${folderStructure} `.trim(); - const initialContextPart: Part = { text: initialContextText }; - // --- End environmental information formatting --- + const initialContextPart: Part = { text: initialContextText }; + // --- End environmental information formatting --- - try { - const chat = this.ai.chats.create({ - model: 'gemini-2.0-flash',//'gemini-2.0-flash', - config: { - systemInstruction: CoreSystemPrompt, - ...this.defaultHyperParameters, - tools, - }, - history: [ - // --- Add the context as a single part in the initial user message --- - { - role: "user", - parts: [initialContextPart] // Pass the single Part object in an array - }, - // --- Add an empty model response to balance the history --- - { - role: "model", - parts: [{ text: "Got it. Thanks for the context!" }] // A slightly more conversational model response - } - // --- End history modification --- - ], - }); - return chat; - } catch (error) { - console.error("Error initializing Gemini chat session:", error); - const message = error instanceof Error ? error.message : "Unknown error."; - throw new Error(`Failed to initialize chat: ${message}`); - } + try { + const chat = this.ai.chats.create({ + model: 'gemini-2.0-flash', //'gemini-2.0-flash', + config: { + systemInstruction: CoreSystemPrompt, + ...this.defaultHyperParameters, + tools, + }, + history: [ + // --- Add the context as a single part in the initial user message --- + { + role: 'user', + parts: [initialContextPart], // Pass the single Part object in an array + }, + // --- Add an empty model response to balance the history --- + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the context!' }], // A slightly more conversational model response + }, + // --- End history modification --- + ], + }); + return chat; + } catch (error) { + console.error('Error initializing Gemini chat session:', error); + const message = error instanceof Error ? error.message : 'Unknown error.'; + throw new Error(`Failed to initialize chat: ${message}`); } + } - public addMessageToHistory(chat: Chat, message: Content): void { - const history = chat.getHistory(); - history.push(message); - this.ai.chats - chat - } + public addMessageToHistory(chat: Chat, message: Content): void { + const history = chat.getHistory(); + history.push(message); + this.ai.chats; + chat; + } - public async* sendMessageStream( - chat: Chat, - request: PartListUnion, - signal?: AbortSignal - ): GeminiStream { - let currentMessageToSend: PartListUnion = request; - let turns = 0; + public async *sendMessageStream( + chat: Chat, + request: PartListUnion, + signal?: AbortSignal, + ): GeminiStream { + let currentMessageToSend: PartListUnion = request; + let turns = 0; - try { - while (turns < this.MAX_TURNS) { - turns++; - const resultStream = await chat.sendMessageStream({ message: currentMessageToSend }); - let functionResponseParts: Part[] = []; - let pendingToolCalls: Array<{ callId: string; name: string; args: Record<string, any> }> = []; - let yieldedTextInTurn = false; - const chunksForDebug = []; + try { + while (turns < this.MAX_TURNS) { + turns++; + const resultStream = await chat.sendMessageStream({ + message: currentMessageToSend, + }); + let functionResponseParts: Part[] = []; + let pendingToolCalls: Array<{ + callId: string; + name: string; + args: Record<string, any>; + }> = []; + let yieldedTextInTurn = false; + const chunksForDebug = []; - for await (const chunk of resultStream) { - chunksForDebug.push(chunk); - if (signal?.aborted) { - const abortError = new Error("Request cancelled by user during stream."); - abortError.name = 'AbortError'; - throw abortError; - } + for await (const chunk of resultStream) { + chunksForDebug.push(chunk); + if (signal?.aborted) { + const abortError = new Error( + 'Request cancelled by user during stream.', + ); + abortError.name = 'AbortError'; + throw abortError; + } - const functionCalls = chunk.functionCalls; - if (functionCalls && functionCalls.length > 0) { - for (const call of functionCalls) { - const callId = call.id ?? `${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`; - const name = call.name || 'undefined_tool_name'; - const args = (call.args || {}) as Record<string, any>; + const functionCalls = chunk.functionCalls; + if (functionCalls && functionCalls.length > 0) { + for (const call of functionCalls) { + const callId = + call.id ?? + `${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const name = call.name || 'undefined_tool_name'; + const args = (call.args || {}) as Record<string, any>; - pendingToolCalls.push({ callId, name, args }); - const evtValue: ToolCallEvent = { - type: 'tool_call', - status: ToolCallStatus.Pending, - callId, - name, - args, - resultDisplay: undefined, - confirmationDetails: undefined, - } - yield { - type: GeminiEventType.ToolCallInfo, - value: evtValue, - }; - } - } else { - const text = chunk.text; - if (text) { - yieldedTextInTurn = true; - yield { - type: GeminiEventType.Content, - value: text, - }; - } - } - } + pendingToolCalls.push({ callId, name, args }); + const evtValue: ToolCallEvent = { + type: 'tool_call', + status: ToolCallStatus.Pending, + callId, + name, + args, + resultDisplay: undefined, + confirmationDetails: undefined, + }; + yield { + type: GeminiEventType.ToolCallInfo, + value: evtValue, + }; + } + } else { + const text = chunk.text; + if (text) { + yieldedTextInTurn = true; + yield { + type: GeminiEventType.Content, + value: text, + }; + } + } + } - if (pendingToolCalls.length > 0) { - const toolPromises: Promise<ToolExecutionOutcome>[] = pendingToolCalls.map(async pendingToolCall => { - const tool = toolRegistry.getTool(pendingToolCall.name); + if (pendingToolCalls.length > 0) { + const toolPromises: Promise<ToolExecutionOutcome>[] = + pendingToolCalls.map(async (pendingToolCall) => { + const tool = toolRegistry.getTool(pendingToolCall.name); - if (!tool) { - // Directly return error outcome if tool not found - return { ...pendingToolCall, error: new Error(`Tool "${pendingToolCall.name}" not found or is not registered.`) }; - } + if (!tool) { + // Directly return error outcome if tool not found + return { + ...pendingToolCall, + error: new Error( + `Tool "${pendingToolCall.name}" not found or is not registered.`, + ), + }; + } - try { - const confirmation = await tool.shouldConfirmExecute(pendingToolCall.args); - if (confirmation) { - return { ...pendingToolCall, confirmationDetails: confirmation }; - } - } catch (error) { - return { ...pendingToolCall, error: new Error(`Tool failed to check tool confirmation: ${error}`) }; - } + try { + const confirmation = await tool.shouldConfirmExecute( + pendingToolCall.args, + ); + if (confirmation) { + return { + ...pendingToolCall, + confirmationDetails: confirmation, + }; + } + } catch (error) { + return { + ...pendingToolCall, + error: new Error( + `Tool failed to check tool confirmation: ${error}`, + ), + }; + } - try { - const result = await tool.execute(pendingToolCall.args); - return { ...pendingToolCall, result }; - } catch (error) { - return { ...pendingToolCall, error: new Error(`Tool failed to execute: ${error}`) }; - } - }); - const toolExecutionOutcomes: ToolExecutionOutcome[] = await Promise.all(toolPromises); + try { + const result = await tool.execute(pendingToolCall.args); + return { ...pendingToolCall, result }; + } catch (error) { + return { + ...pendingToolCall, + error: new Error(`Tool failed to execute: ${error}`), + }; + } + }); + const toolExecutionOutcomes: ToolExecutionOutcome[] = + await Promise.all(toolPromises); - for (const executedTool of toolExecutionOutcomes) { - const { callId, name, args, result, error, confirmationDetails } = executedTool; + for (const executedTool of toolExecutionOutcomes) { + const { callId, name, args, result, error, confirmationDetails } = + executedTool; - if (error) { - const errorMessage = error?.message || String(error); - yield { - type: GeminiEventType.Content, - value: `[Error invoking tool ${name}: ${errorMessage}]`, - }; - } else if (result && typeof result === 'object' && result !== null && 'error' in result) { - const errorMessage = String(result.error); - yield { - type: GeminiEventType.Content, - value: `[Error executing tool ${name}: ${errorMessage}]`, - }; - } else { - const status = confirmationDetails ? ToolCallStatus.Confirming : ToolCallStatus.Invoked; - const evtValue: ToolCallEvent = { type: 'tool_call', status, callId, name, args, resultDisplay: result?.returnDisplay, confirmationDetails } - yield { - type: GeminiEventType.ToolCallInfo, - value: evtValue, - }; - } - } + if (error) { + const errorMessage = error?.message || String(error); + yield { + type: GeminiEventType.Content, + value: `[Error invoking tool ${name}: ${errorMessage}]`, + }; + } else if ( + result && + typeof result === 'object' && + result !== null && + 'error' in result + ) { + const errorMessage = String(result.error); + yield { + type: GeminiEventType.Content, + value: `[Error executing tool ${name}: ${errorMessage}]`, + }; + } else { + const status = confirmationDetails + ? ToolCallStatus.Confirming + : ToolCallStatus.Invoked; + const evtValue: ToolCallEvent = { + type: 'tool_call', + status, + callId, + name, + args, + resultDisplay: result?.returnDisplay, + confirmationDetails, + }; + yield { + type: GeminiEventType.ToolCallInfo, + value: evtValue, + }; + } + } - pendingToolCalls = []; + pendingToolCalls = []; - const waitingOnConfirmations = toolExecutionOutcomes.filter(outcome => outcome.confirmationDetails).length > 0; - if (waitingOnConfirmations) { - // Stop processing content, wait for user. - // TODO: Kill token processing once API supports signals. - break; - } + const waitingOnConfirmations = + toolExecutionOutcomes.filter( + (outcome) => outcome.confirmationDetails, + ).length > 0; + if (waitingOnConfirmations) { + // Stop processing content, wait for user. + // TODO: Kill token processing once API supports signals. + break; + } - functionResponseParts = toolExecutionOutcomes.map((executedTool: ToolExecutionOutcome): Part => { - const { name, result, error } = executedTool; - const output = { "output": result?.llmContent }; - let toolOutcomePayload: any; + functionResponseParts = toolExecutionOutcomes.map( + (executedTool: ToolExecutionOutcome): Part => { + const { name, result, error } = executedTool; + const output = { output: result?.llmContent }; + let toolOutcomePayload: any; - if (error) { - const errorMessage = error?.message || String(error); - toolOutcomePayload = { error: `Invocation failed: ${errorMessage}` }; - console.error(`[Turn ${turns}] Critical error invoking tool ${name}:`, error); - } else if (result && typeof result === 'object' && result !== null && 'error' in result) { - toolOutcomePayload = output; - console.warn(`[Turn ${turns}] Tool ${name} returned an error structure:`, result.error); - } else { - toolOutcomePayload = output; - } + if (error) { + const errorMessage = error?.message || String(error); + toolOutcomePayload = { + error: `Invocation failed: ${errorMessage}`, + }; + console.error( + `[Turn ${turns}] Critical error invoking tool ${name}:`, + error, + ); + } else if ( + result && + typeof result === 'object' && + result !== null && + 'error' in result + ) { + toolOutcomePayload = output; + console.warn( + `[Turn ${turns}] Tool ${name} returned an error structure:`, + result.error, + ); + } else { + toolOutcomePayload = output; + } - return { - functionResponse: { - name: name, - id: executedTool.callId, - response: toolOutcomePayload, - }, - }; - }); - currentMessageToSend = functionResponseParts; - } else if (yieldedTextInTurn) { - const history = chat.getHistory(); - const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). + return { + functionResponse: { + name: name, + id: executedTool.callId, + response: toolOutcomePayload, + }, + }; + }, + ); + currentMessageToSend = functionResponseParts; + } else if (yieldedTextInTurn) { + const history = chat.getHistory(); + const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). **Decision Rules (apply in order):** @@ -274,110 +349,135 @@ Respond *only* in JSON format according to the following schema. Do not include \`\`\` }`; - // Schema Idea - const responseSchema: SchemaUnion = { - type: Type.OBJECT, - properties: { - reasoning: { - type: Type.STRING, - description: "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn." - }, - next_speaker: { - type: Type.STRING, - enum: ['user', 'model'], // Enforce the choices - description: "Who should speak next based *only* on the preceding turn and the decision rules", - }, - }, - required: ['reasoning', 'next_speaker'] - }; - - try { - // Use the new generateJson method, passing the history and the check prompt - const parsedResponse = await this.generateJson([...history, { role: "user", parts: [{ text: checkPrompt }] }], responseSchema); - - // Safely extract the next speaker value - const nextSpeaker: string | undefined = typeof parsedResponse?.next_speaker === 'string' ? parsedResponse.next_speaker : undefined; - - if (nextSpeaker === 'model') { - currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt - } else { - // 'user' should speak next, or value is missing/invalid. End the turn. - break; - } + // Schema Idea + const responseSchema: SchemaUnion = { + type: Type.OBJECT, + properties: { + reasoning: { + type: Type.STRING, + description: + "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.", + }, + next_speaker: { + type: Type.STRING, + enum: ['user', 'model'], // Enforce the choices + description: + 'Who should speak next based *only* on the preceding turn and the decision rules', + }, + }, + required: ['reasoning', 'next_speaker'], + }; - } catch (error) { - console.error(`[Turn ${turns}] Failed to get or parse next speaker check:`, error); - // If the check fails, assume user should speak next to avoid infinite loops - break; - } - } else { - console.warn(`[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`); - break; - } - - } + try { + // Use the new generateJson method, passing the history and the check prompt + const parsedResponse = await this.generateJson( + [ + ...history, + { + role: 'user', + parts: [{ text: checkPrompt }], + }, + ], + responseSchema, + ); - if (turns >= this.MAX_TURNS) { - console.warn("sendMessageStream: Reached maximum tool call turns limit."); - yield { - type: GeminiEventType.Content, - value: "\n\n[System Notice: Maximum interaction turns reached. The conversation may be incomplete.]", - }; - } + // Safely extract the next speaker value + const nextSpeaker: string | undefined = + typeof parsedResponse?.next_speaker === 'string' + ? parsedResponse.next_speaker + : undefined; - } catch (error: unknown) { - if (error instanceof Error && error.name === 'AbortError') { - console.log("Gemini stream request aborted by user."); - throw error; + if (nextSpeaker === 'model') { + currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt } else { - console.error(`Error during Gemini stream or tool interaction:`, error); - const message = error instanceof Error ? error.message : String(error); - yield { - type: GeminiEventType.Content, - value: `\n\n[Error: An unexpected error occurred during the chat: ${message}]`, - }; - throw error; + // 'user' should speak next, or value is missing/invalid. End the turn. + break; } + } catch (error) { + console.error( + `[Turn ${turns}] Failed to get or parse next speaker check:`, + error, + ); + // If the check fails, assume user should speak next to avoid infinite loops + break; + } + } else { + console.warn( + `[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`, + ); + break; } + } + + if (turns >= this.MAX_TURNS) { + console.warn( + 'sendMessageStream: Reached maximum tool call turns limit.', + ); + yield { + type: GeminiEventType.Content, + value: + '\n\n[System Notice: Maximum interaction turns reached. The conversation may be incomplete.]', + }; + } + } catch (error: unknown) { + if (error instanceof Error && error.name === 'AbortError') { + console.log('Gemini stream request aborted by user.'); + throw error; + } else { + console.error(`Error during Gemini stream or tool interaction:`, error); + const message = error instanceof Error ? error.message : String(error); + yield { + type: GeminiEventType.Content, + value: `\n\n[Error: An unexpected error occurred during the chat: ${message}]`, + }; + throw error; + } } + } - /** - * Generates structured JSON content based on conversational history and a schema. - * @param contents The conversational history (Content array) to provide context. - * @param schema The SchemaUnion defining the desired JSON structure. - * @returns A promise that resolves to the parsed JSON object matching the schema. - * @throws Throws an error if the API call fails or the response is not valid JSON. - */ - public async generateJson(contents: Content[], schema: SchemaUnion): Promise<any> { - try { - const result = await this.ai.models.generateContent({ - model: 'gemini-2.0-flash', // Using flash for potentially faster structured output - config: { - ...this.defaultHyperParameters, - systemInstruction: CoreSystemPrompt, - responseSchema: schema, - responseMimeType: 'application/json', - }, - contents: contents, // Pass the full Content array - }); + /** + * Generates structured JSON content based on conversational history and a schema. + * @param contents The conversational history (Content array) to provide context. + * @param schema The SchemaUnion defining the desired JSON structure. + * @returns A promise that resolves to the parsed JSON object matching the schema. + * @throws Throws an error if the API call fails or the response is not valid JSON. + */ + public async generateJson( + contents: Content[], + schema: SchemaUnion, + ): Promise<any> { + try { + const result = await this.ai.models.generateContent({ + model: 'gemini-2.0-flash', // Using flash for potentially faster structured output + config: { + ...this.defaultHyperParameters, + systemInstruction: CoreSystemPrompt, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents: contents, // Pass the full Content array + }); - const responseText = result.text; - if (!responseText) { - throw new Error("API returned an empty response."); - } + const responseText = result.text; + if (!responseText) { + throw new Error('API returned an empty response.'); + } - try { - const parsedJson = JSON.parse(responseText); - // TODO: Add schema validation if needed - return parsedJson; - } catch (parseError) { - console.error("Failed to parse JSON response:", responseText); - throw new Error(`Failed to parse API response as JSON: ${parseError instanceof Error ? parseError.message : String(parseError)}`); - } - } catch (error) { - console.error("Error generating JSON content:", error); - const message = error instanceof Error ? error.message : "Unknown API error."; - throw new Error(`Failed to generate JSON content: ${message}`); - } + try { + const parsedJson = JSON.parse(responseText); + // TODO: Add schema validation if needed + return parsedJson; + } catch (parseError) { + console.error('Failed to parse JSON response:', responseText); + throw new Error( + `Failed to parse API response as JSON: ${parseError instanceof Error ? parseError.message : String(parseError)}`, + ); + } + } catch (error) { + console.error('Error generating JSON content:', error); + const message = + error instanceof Error ? error.message : 'Unknown API error.'; + throw new Error(`Failed to generate JSON content: ${message}`); } + } } |
