diff options
Diffstat (limited to 'packages/server/src/core')
| -rw-r--r-- | packages/server/src/core/gemini-client.ts | 18 | ||||
| -rw-r--r-- | packages/server/src/core/turn.ts | 134 |
2 files changed, 110 insertions, 42 deletions
diff --git a/packages/server/src/core/gemini-client.ts b/packages/server/src/core/gemini-client.ts index d78a0559..b9b44534 100644 --- a/packages/server/src/core/gemini-client.ts +++ b/packages/server/src/core/gemini-client.ts @@ -18,15 +18,7 @@ import { import { CoreSystemPrompt } from './prompts.js'; import process from 'node:process'; import { getFolderStructure } from '../utils/getFolderStructure.js'; -import { Turn, ServerTool, GeminiEventType } from './turn.js'; - -// Import the ServerGeminiStreamEvent type -type ServerGeminiStreamEvent = - | { type: GeminiEventType.Content; value: string } - | { - type: GeminiEventType.ToolCallRequest; - value: { callId: string; name: string; args: Record<string, unknown> }; - }; +import { Turn, ServerTool, ServerGeminiStreamEvent } from './turn.js'; export class GeminiClient { private ai: GoogleGenAI; @@ -112,6 +104,14 @@ export class GeminiClient { for await (const event of resultStream) { yield event; } + + const confirmations = turn.getConfirmationDetails(); + if (confirmations.length > 0) { + break; + } + + // What do we do when we have both function responses and confirmations? + const fnResponses = turn.getFunctionResponses(); if (fnResponses.length > 0) { request = fnResponses; diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 3d8c8c76..0a1c594c 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -13,7 +13,11 @@ import { FunctionDeclaration, } from '@google/genai'; // Removed UI type imports -import { ToolResult } from '../tools/tools.js'; // Keep ToolResult for now +import { + ToolCallConfirmationDetails, + ToolResult, + ToolResultDisplay, +} from '../tools/tools.js'; // Keep ToolResult for now // Removed gemini-stream import (types defined locally) // --- Types for Server Logic --- @@ -25,7 +29,7 @@ interface ServerToolExecutionOutcome { args: Record<string, unknown>; // Use unknown for broader compatibility result?: ToolResult; error?: Error; - // Confirmation details are handled by CLI, not server logic + confirmationDetails: ToolCallConfirmationDetails | undefined; } // Define a structure for tools passed to the server @@ -34,6 +38,9 @@ export interface ServerTool { schema: FunctionDeclaration; // Schema is needed // The execute method signature might differ slightly or be wrapped execute(params: Record<string, unknown>): Promise<ToolResult>; + shouldConfirmExecute( + params: Record<string, unknown>, + ): Promise<ToolCallConfirmationDetails | false>; // validation and description might be handled differently or passed } @@ -41,17 +48,36 @@ export interface ServerTool { export enum GeminiEventType { Content = 'content', ToolCallRequest = 'tool_call_request', + ToolCallResponse = 'tool_call_response', + ToolCallConfirmation = 'tool_call_confirmation', } -interface ToolCallRequestInfo { +export interface ToolCallRequestInfo { callId: string; name: string; args: Record<string, unknown>; } -type ServerGeminiStreamEvent = +export interface ToolCallResponseInfo { + callId: string; + responsePart: Part; + resultDisplay: ToolResultDisplay | undefined; + error: Error | undefined; +} + +export interface ServerToolCallConfirmationDetails { + request: ToolCallRequestInfo; + details: ToolCallConfirmationDetails; +} + +export type ServerGeminiStreamEvent = | { type: GeminiEventType.Content; value: string } - | { type: GeminiEventType.ToolCallRequest; value: ToolCallRequestInfo }; + | { type: GeminiEventType.ToolCallRequest; value: ToolCallRequestInfo } + | { type: GeminiEventType.ToolCallResponse; value: ToolCallResponseInfo } + | { + type: GeminiEventType.ToolCallConfirmation; + value: ServerToolCallConfirmationDetails; + }; // --- Turn Class (Refactored for Server) --- @@ -65,6 +91,7 @@ export class Turn { args: Record<string, unknown>; // Use unknown }>; private fnResponses: Part[]; + private confirmationDetails: ToolCallConfirmationDetails[]; private debugResponses: GenerateContentResponse[]; constructor(chat: Chat, availableTools: ServerTool[]) { @@ -72,6 +99,7 @@ export class Turn { this.availableTools = new Map(availableTools.map((t) => [t.name, t])); this.pendingToolCalls = []; this.fnResponses = []; + this.confirmationDetails = []; this.debugResponses = []; } @@ -113,19 +141,31 @@ export class Turn { error: new Error( `Tool "${pendingToolCall.name}" not found or not provided to Turn.`, ), + confirmationDetails: undefined, }; } - // No confirmation logic in the server Turn + try { - // TODO: Add validation step if needed (tool.validateParams?) - const result = await tool.execute(pendingToolCall.args); - return { ...pendingToolCall, result }; + const confirmationDetails = await tool.shouldConfirmExecute( + pendingToolCall.args, + ); + if (confirmationDetails) { + return { ...pendingToolCall, confirmationDetails }; + } else { + const result = await tool.execute(pendingToolCall.args); + return { + ...pendingToolCall, + result, + confirmationDetails: undefined, + }; + } } catch (execError: unknown) { return { ...pendingToolCall, error: new Error( `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, ), + confirmationDetails: undefined, }; } }, @@ -133,9 +173,37 @@ export class Turn { const outcomes = await Promise.all(toolPromises); // Process outcomes and prepare function responses - this.fnResponses = this.buildFunctionResponses(outcomes); this.pendingToolCalls = []; // Clear pending calls for this turn + for (let i = 0; i < outcomes.length; i++) { + const outcome = outcomes[i]; + if (outcome.confirmationDetails) { + this.confirmationDetails.push(outcome.confirmationDetails); + const serverConfirmationetails: ServerToolCallConfirmationDetails = { + request: { + callId: outcome.callId, + name: outcome.name, + args: outcome.args, + }, + details: outcome.confirmationDetails, + }; + yield { + type: GeminiEventType.ToolCallConfirmation, + value: serverConfirmationetails, + }; + } else { + const responsePart = this.buildFunctionResponse(outcome); + this.fnResponses.push(responsePart); + const responseInfo: ToolCallResponseInfo = { + callId: outcome.callId, + responsePart, + resultDisplay: outcome.result?.returnDisplay, + error: outcome.error, + }; + yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + } + } + // If there were function responses, the caller (GeminiService) will loop // and call run() again with these responses. // If no function responses, the turn ends here. @@ -160,31 +228,27 @@ export class Turn { } // Builds the Part array expected by the Google GenAI API - private buildFunctionResponses( - outcomes: ServerToolExecutionOutcome[], - ): Part[] { - return outcomes.map((outcome): Part => { - const { name, result, error } = outcome; - let fnResponsePayload: Record<string, unknown>; + private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part { + const { name, result, error } = outcome; + let fnResponsePayload: Record<string, unknown>; - if (error) { - // Format error for the LLM - const errorMessage = error?.message || String(error); - fnResponsePayload = { error: `Tool execution failed: ${errorMessage}` }; - console.error(`[Server Turn] Error executing tool ${name}:`, error); - } else { - // Pass successful tool result (content meant for LLM) - fnResponsePayload = { output: result?.llmContent ?? '' }; // Default to empty string if no content - } + if (error) { + // Format error for the LLM + const errorMessage = error?.message || String(error); + fnResponsePayload = { error: `Tool execution failed: ${errorMessage}` }; + console.error(`[Server Turn] Error executing tool ${name}:`, error); + } else { + // Pass successful tool result (content meant for LLM) + fnResponsePayload = { output: result?.llmContent ?? '' }; // Default to empty string if no content + } - return { - functionResponse: { - name, - id: outcome.callId, - response: fnResponsePayload, - }, - }; - }); + return { + functionResponse: { + name, + id: outcome.callId, + response: fnResponsePayload, + }, + }; } private abortError(): Error { @@ -193,6 +257,10 @@ export class Turn { return error; // Return instead of throw, let caller handle } + getConfirmationDetails(): ToolCallConfirmationDetails[] { + return this.confirmationDetails; + } + // Allows the service layer to get the responses needed for the next API call getFunctionResponses(): Part[] { return this.fnResponses; |
