 packages/cli/src/ui/hooks/useGeminiStream.ts         |   7 +-
 packages/server/src/core/client.ts                   |  16 +-
 packages/server/src/core/geminiChat.ts               | 314 ++++++++++++++
 packages/server/src/core/turn.test.ts                |   5 +-
 packages/server/src/core/turn.ts                     |   4 +-
 packages/server/src/utils/nextSpeakerChecker.test.ts |   7 +-
 packages/server/src/utils/nextSpeakerChecker.ts      |   5 +-
 7 files changed, 339 insertions(+), 19 deletions(-)
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 76d29189..f369d796 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -23,7 +23,7 @@ import {
   ToolResultDisplay,
   ToolCallRequestInfo,
 } from '@gemini-code/server';
-import { type Chat, type PartListUnion, type Part } from '@google/genai';
+import { type PartListUnion, type Part } from '@google/genai';
 import {
   StreamingState,
   ToolCallStatus,
@@ -39,6 +39,7 @@ import { useStateAndRef } from './useStateAndRef.js';
 import { UseHistoryManagerReturn } from './useHistoryManager.js';
 import { useLogger } from './useLogger.js';
 import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
+import { GeminiChat } from '@gemini-code/server/src/core/geminiChat.js';
 
 enum StreamProcessingStatus {
   Completed,
@@ -63,7 +64,7 @@ export const useGeminiStream = (
 ) => {
   const [initError, setInitError] = useState<string | null>(null);
   const abortControllerRef = useRef<AbortController | null>(null);
-  const chatSessionRef = useRef<Chat | null>(null);
+  const chatSessionRef = useRef<GeminiChat | null>(null);
   const geminiClientRef = useRef<GeminiClient | null>(null);
   const [isResponding, setIsResponding] = useState<boolean>(false);
   const [pendingHistoryItemRef, setPendingHistoryItem] =
@@ -235,7 +236,7 @@ export const useGeminiStream = (
 
   const ensureChatSession = useCallback(async (): Promise<{
     client: GeminiClient | null;
-    chat: Chat | null;
+    chat: GeminiChat | null;
   }> => {
     const currentClient = geminiClientRef.current;
     if (!currentClient) {
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts
index 85850da8..3d5927e3 100644
--- a/packages/server/src/core/client.ts
+++ b/packages/server/src/core/client.ts
@@ -8,7 +8,6 @@ import {
   GenerateContentConfig,
   GoogleGenAI,
   Part,
-  Chat,
   SchemaUnion,
   PartListUnion,
   Content,
@@ -23,6 +22,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js';
 import { getResponseText } from '../utils/generateContentResponseUtilities.js';
 import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
 import { reportError } from '../utils/errorReporting.js';
+import { GeminiChat } from './geminiChat.js';
 
 export class GeminiClient {
   private client: GoogleGenAI;
@@ -108,7 +108,7 @@ export class GeminiClient {
     return initialParts;
   }
 
-  async startChat(): Promise<Chat> {
+  async startChat(): Promise<GeminiChat> {
     const envParts = await this.getEnvironment();
     const toolDeclarations = this.config
       .getToolRegistry()
@@ -128,15 +128,17 @@ export class GeminiClient {
       const userMemory = this.config.getUserMemory();
       const systemInstruction = getCoreSystemPrompt(userMemory);
 
-      return this.client.chats.create({
-        model: this.model,
-        config: {
+      return new GeminiChat(
+        this.client,
+        this.client.models,
+        this.model,
+        {
           systemInstruction,
           ...this.generateContentConfig,
           tools,
         },
         history,
-      });
+      );
     } catch (error) {
       await reportError(
         error,
@@ -150,7 +152,7 @@ export class GeminiClient {
   }
 
   async *sendMessageStream(
-    chat: Chat,
+    chat: GeminiChat,
     request: PartListUnion,
     signal?: AbortSignal,
     turns: number = this.MAX_TURNS,
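The `startChat()` hunk above swaps the SDK's `ai.chats.create(...)` factory for direct construction of the local `GeminiChat` copy. A minimal sketch of the resulting call shape; the model name and system instruction are illustrative, not taken from this diff:

```ts
import { GoogleGenAI } from '@google/genai';
import { GeminiChat } from './geminiChat.js';

const ai = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY });

// Before: ai.chats.create({ model, config, history })
// After:  pass the client and its models module explicitly, so the local
// GeminiChat can issue generateContent / generateContentStream calls itself.
const chat = new GeminiChat(
  ai,                 // apiClient
  ai.models,          // modelsModule
  'gemini-2.0-flash', // model (illustrative)
  { systemInstruction: 'You are a helpful coding assistant.' },
  [],                 // starting history
);
```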
diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts
new file mode 100644
index 00000000..dd5f3b7a
--- /dev/null
+++ b/packages/server/src/core/geminiChat.ts
@@ -0,0 +1,314 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
+// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
+
+import {
+  GenerateContentResponse,
+  Content,
+  Models,
+  GenerateContentConfig,
+  SendMessageParameters,
+  GoogleGenAI,
+  createUserContent,
+} from '@google/genai';
+
+/**
+ * Returns true if the response is valid, false otherwise.
+ */
+function isValidResponse(response: GenerateContentResponse): boolean {
+  if (response.candidates === undefined || response.candidates.length === 0) {
+    return false;
+  }
+  const content = response.candidates[0]?.content;
+  if (content === undefined) {
+    return false;
+  }
+  return isValidContent(content);
+}
+
+function isValidContent(content: Content): boolean {
+  if (content.parts === undefined || content.parts.length === 0) {
+    return false;
+  }
+  for (const part of content.parts) {
+    if (part === undefined || Object.keys(part).length === 0) {
+      return false;
+    }
+    if (!part.thought && part.text !== undefined && part.text === '') {
+      return false;
+    }
+  }
+  return true;
+}
+
+/**
+ * Validates the history contains the correct roles.
+ *
+ * @throws Error if the history does not start with a user turn.
+ * @throws Error if the history contains an invalid role.
+ */
+function validateHistory(history: Content[]) {
+  // Empty history is valid.
+  if (history.length === 0) {
+    return;
+  }
+  for (const content of history) {
+    if (content.role !== 'user' && content.role !== 'model') {
+      throw new Error(`Role must be user or model, but got ${content.role}.`);
+    }
+  }
+}
+
+/**
+ * Extracts the curated (valid) history from a comprehensive history.
+ *
+ * @remarks
+ * The model may sometimes generate invalid or empty contents (e.g., due to
+ * safety filters or recitation). Extracting valid turns from the history
+ * ensures that subsequent requests can be accepted by the model.
+ */
+function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
+  if (comprehensiveHistory === undefined || comprehensiveHistory.length === 0) {
+    return [];
+  }
+  const curatedHistory: Content[] = [];
+  const length = comprehensiveHistory.length;
+  let i = 0;
+  while (i < length) {
+    if (comprehensiveHistory[i].role === 'user') {
+      curatedHistory.push(comprehensiveHistory[i]);
+      i++;
+    } else {
+      const modelOutput: Content[] = [];
+      let isValid = true;
+      while (i < length && comprehensiveHistory[i].role === 'model') {
+        modelOutput.push(comprehensiveHistory[i]);
+        if (isValid && !isValidContent(comprehensiveHistory[i])) {
+          isValid = false;
+        }
+        i++;
+      }
+      if (isValid) {
+        curatedHistory.push(...modelOutput);
+      } else {
+        // Remove the last user input when model content is invalid.
+        curatedHistory.pop();
+      }
+    }
+  }
+  return curatedHistory;
+}
+
+/**
+ * Chat session that enables sending messages to the model with previous
+ * conversation context.
+ *
+ * @remarks
+ * The session maintains all the turns between user and model.
+ */
+export class GeminiChat {
+  // A promise to represent the current state of the message being sent to the
+  // model.
+  private sendPromise: Promise<void> = Promise.resolve();
+
+  constructor(
+    private readonly apiClient: GoogleGenAI,
+    private readonly modelsModule: Models,
+    private readonly model: string,
+    private readonly config: GenerateContentConfig = {},
+    private history: Content[] = [],
+  ) {
+    validateHistory(history);
+  }
+
+  /**
+   * Sends a message to the model and returns the response.
+   *
+   * @remarks
+   * This method will wait for the previous message to be processed before
+   * sending the next message.
+   *
+   * @see {@link GeminiChat#sendMessageStream} for the streaming method.
+   * @param params - parameters for sending messages within a chat session.
+   * @returns The model's response.
+   *
+   * @example
+   * ```ts
+   * const chat = ai.chats.create({model: 'gemini-2.0-flash'});
+   * const response = await chat.sendMessage({
+   *   message: 'Why is the sky blue?'
+   * });
+   * console.log(response.text);
+   * ```
+   */
+  async sendMessage(
+    params: SendMessageParameters,
+  ): Promise<GenerateContentResponse> {
+    await this.sendPromise;
+    const userContent = createUserContent(params.message);
+    const responsePromise = this.modelsModule.generateContent({
+      model: this.model,
+      contents: this.getHistory(true).concat(userContent),
+      config: params.config ?? this.config,
+    });
+    this.sendPromise = (async () => {
+      const response = await responsePromise;
+      const outputContent = response.candidates?.[0]?.content;
+
+      // Because the AFC input contains the entire curated chat history in
+      // addition to the new user input, we need to truncate the AFC history
+      // to deduplicate the existing chat history.
+      const fullAutomaticFunctionCallingHistory =
+        response.automaticFunctionCallingHistory;
+      const index = this.getHistory(true).length;
+
+      let automaticFunctionCallingHistory: Content[] = [];
+      if (fullAutomaticFunctionCallingHistory != null) {
+        automaticFunctionCallingHistory =
+          fullAutomaticFunctionCallingHistory.slice(index) ?? [];
+      }
+
+      const modelOutput = outputContent ? [outputContent] : [];
+      this.recordHistory(
+        userContent,
+        modelOutput,
+        automaticFunctionCallingHistory,
+      );
+      return;
+    })();
+    await this.sendPromise.catch(() => {
+      // Resets sendPromise to avoid subsequent calls failing
+      this.sendPromise = Promise.resolve();
+    });
+    return responsePromise;
+  }
+
+  /**
+   * Sends a message to the model and returns the response in chunks.
+   *
+   * @remarks
+   * This method will wait for the previous message to be processed before
+   * sending the next message.
+   *
+   * @see {@link GeminiChat#sendMessage} for the non-streaming method.
+   * @param params - parameters for sending the message.
+   * @return The model's response.
+   *
+   * @example
+   * ```ts
+   * const chat = ai.chats.create({model: 'gemini-2.0-flash'});
+   * const response = await chat.sendMessageStream({
+   *   message: 'Why is the sky blue?'
+   * });
+   * for await (const chunk of response) {
+   *   console.log(chunk.text);
+   * }
+   * ```
+   */
+  async sendMessageStream(
+    params: SendMessageParameters,
+  ): Promise<AsyncGenerator<GenerateContentResponse>> {
+    await this.sendPromise;
+    const userContent = createUserContent(params.message);
+    const streamResponse = this.modelsModule.generateContentStream({
+      model: this.model,
+      contents: this.getHistory(true).concat(userContent),
+      config: params.config ?? this.config,
+    });
+    // Resolve the internal send-completion promise - `sendPromise` - for both
+    // success and failure responses. The actual failure is still propagated
+    // by the `await streamResponse`.
+    this.sendPromise = streamResponse
+      .then(() => undefined)
+      .catch(() => undefined);
+    const response = await streamResponse;
+    const result = this.processStreamResponse(response, userContent);
+    return result;
+  }
+
+  /**
+   * Returns the chat history.
+   *
+   * @remarks
+   * The history is a list of contents alternating between user and model.
+   *
+   * There are two types of history:
+   * - The `curated history` contains only the valid turns between user and
+   *   model, which will be included in the subsequent requests sent to the
+   *   model.
+   * - The `comprehensive history` contains all turns, including invalid or
+   *   empty model outputs, providing a complete record of the history.
+   *
+   * The history is updated after receiving the response from the model; for
+   * a streaming response, this means after the last chunk of the response
+   * has been received.
+   *
+   * The `comprehensive history` is returned by default. To get the `curated
+   * history`, set the `curated` parameter to `true`.
+   *
+   * @param curated - whether to return the curated history or the
+   *   comprehensive history.
+   * @return History contents alternating between user and model for the
+   *   entire chat session.
+   */
+  getHistory(curated: boolean = false): Content[] {
+    const history = curated
+      ? extractCuratedHistory(this.history)
+      : this.history;
+    // Deep copy the history to avoid mutating the history outside of the
+    // chat session.
+    return structuredClone(history);
+  }
+
+  private async *processStreamResponse(
+    streamResponse: AsyncGenerator<GenerateContentResponse>,
+    inputContent: Content,
+  ) {
+    const outputContent: Content[] = [];
+    for await (const chunk of streamResponse) {
+      if (isValidResponse(chunk)) {
+        const content = chunk.candidates?.[0]?.content;
+        if (content !== undefined) {
+          outputContent.push(content);
+        }
+      }
+      yield chunk;
+    }
+    this.recordHistory(inputContent, outputContent);
+  }
+
+  private recordHistory(
+    userInput: Content,
+    modelOutput: Content[],
+    automaticFunctionCallingHistory?: Content[],
+  ) {
+    let outputContents: Content[] = [];
+    if (
+      modelOutput.length > 0 &&
+      modelOutput.every((content) => content.role !== undefined)
+    ) {
+      outputContents = modelOutput;
+    } else {
+      // Append an empty content when the model returns an empty response, so
+      // that the history always alternates between user and model.
+      outputContents.push({
+        role: 'model',
+        parts: [],
+      } as Content);
+    }
+    if (
+      automaticFunctionCallingHistory &&
+      automaticFunctionCallingHistory.length > 0
+    ) {
+      this.history.push(
+        ...extractCuratedHistory(automaticFunctionCallingHistory!),
+      );
+    } else {
+      this.history.push(userInput);
+    }
+    this.history.push(...outputContents);
+  }
+}
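The interplay between `recordHistory` and `extractCuratedHistory` is the crux of the workaround: empty model turns are recorded so that roles keep alternating, but they (together with the user turn that triggered them) are filtered out of the history that is resent to the model. A small illustration with hypothetical contents (`extractCuratedHistory` is file-private above; callers observe the same behavior through `getHistory(true)`):

```ts
import { Content } from '@google/genai';

const comprehensive: Content[] = [
  { role: 'user', parts: [{ text: 'First question' }] },
  // Recorded by recordHistory() when the model returned nothing usable,
  // keeping the user/model alternation intact:
  { role: 'model', parts: [] },
  { role: 'user', parts: [{ text: 'Second question' }] },
  { role: 'model', parts: [{ text: 'An answer.' }] },
];

// getHistory()     -> all four entries (comprehensive view)
// getHistory(true) -> only the second exchange: the empty model turn fails
//                     isValidContent(), so it and the user turn that
//                     produced it are both dropped from the curated view
```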
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
index 90d3407f..44bb983f 100644
--- a/packages/server/src/core/turn.test.ts
+++ b/packages/server/src/core/turn.test.ts
@@ -11,8 +11,9 @@ import {
   ServerGeminiToolCallRequestEvent,
   ServerGeminiErrorEvent,
 } from './turn.js';
-import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
+import { GenerateContentResponse, Part, Content } from '@google/genai';
 import { reportError } from '../utils/errorReporting.js';
+import { GeminiChat } from './geminiChat.js';
 
 const mockSendMessageStream = vi.fn();
 const mockGetHistory = vi.fn();
@@ -54,7 +55,7 @@ describe('Turn', () => {
       sendMessageStream: mockSendMessageStream,
       getHistory: mockGetHistory,
     };
-    turn = new Turn(mockChatInstance as unknown as Chat);
+    turn = new Turn(mockChatInstance as unknown as GeminiChat);
     mockGetHistory.mockReturnValue([]);
     mockSendMessageStream.mockResolvedValue((async function* () {})());
   });
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index a02b5eb6..d5c7eb58 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -6,7 +6,6 @@
 
 import {
   Part,
-  Chat,
   PartListUnion,
   GenerateContentResponse,
   FunctionCall,
@@ -20,6 +19,7 @@ import {
 import { getResponseText } from '../utils/generateContentResponseUtilities.js';
 import { reportError } from '../utils/errorReporting.js';
 import { getErrorMessage } from '../utils/errors.js';
+import { GeminiChat } from './geminiChat.js';
 
 // Define a structure for tools passed to the server
 export interface ServerTool {
@@ -113,7 +113,7 @@ export class Turn {
   }>;
   private debugResponses: GenerateContentResponse[];
 
-  constructor(private readonly chat: Chat) {
+  constructor(private readonly chat: GeminiChat) {
     this.pendingToolCalls = [];
     this.debugResponses = [];
   }
diff --git a/packages/server/src/utils/nextSpeakerChecker.test.ts b/packages/server/src/utils/nextSpeakerChecker.test.ts
index b8d17875..f32227e9 100644
--- a/packages/server/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.test.ts
@@ -5,10 +5,11 @@
  */
 
 import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
-import { Chat, Content } from '@google/genai';
+import { Content } from '@google/genai';
 import { GeminiClient } from '../core/client.js';
 import { Config } from '../config/config.js'; // Added Config import
 import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
+import { GeminiChat } from '../core/geminiChat.js';
 
 // Mock GeminiClient and Config constructor
 vi.mock('../core/client.js');
@@ -39,7 +40,7 @@ vi.mock('@google/genai', async () => {
 });
 
 describe('checkNextSpeaker', () => {
-  let mockChat: Chat;
+  let mockChat: GeminiChat;
   let mockGeminiClient: GeminiClient;
   let MockConfig: Mock;
 
@@ -64,7 +65,7 @@ describe('checkNextSpeaker', () => {
     mockGeminiClient = new GeminiClient(mockConfigInstance);
 
     // Simulate chat creation as done in GeminiClient
-    mockChat = { getHistory: mockGetHistory } as unknown as Chat;
+    mockChat = { getHistory: mockGetHistory } as unknown as GeminiChat;
   });
 
   afterEach(() => {
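In both test files the mocks stub only the methods the unit under test actually calls, then cast to `GeminiChat`. A minimal vitest sketch of that pattern (variable names are illustrative); the double cast is needed because `GeminiChat` has private fields that a structural mock cannot satisfy:

```ts
import { vi } from 'vitest';
import { GeminiChat } from '../core/geminiChat.js';

// Stub just getHistory and sendMessageStream; the empty async generator
// mirrors the mock used in turn.test.ts above.
const mockChat = {
  getHistory: vi.fn().mockReturnValue([]),
  sendMessageStream: vi.fn().mockResolvedValue((async function* () {})()),
} as unknown as GeminiChat;
```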
diff --git a/packages/server/src/utils/nextSpeakerChecker.ts b/packages/server/src/utils/nextSpeakerChecker.ts
index 5eb0c512..3fe813db 100644
--- a/packages/server/src/utils/nextSpeakerChecker.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.ts
@@ -4,8 +4,9 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { Chat, Content, SchemaUnion, Type } from '@google/genai';
+import { Content, SchemaUnion, Type } from '@google/genai';
 import { GeminiClient } from '../core/client.js';
+import { GeminiChat } from '../core/geminiChat.js';
 
 const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
 **Decision Rules (apply in order):**
@@ -57,7 +58,7 @@ export interface NextSpeakerResponse {
 }
 
 export async function checkNextSpeaker(
-  chat: Chat,
+  chat: GeminiChat,
   geminiClient: GeminiClient,
 ): Promise<NextSpeakerResponse | null> {
   // We need to capture the curated history because there are many moments when the model will return invalid turns
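The closing comment in the hunk above is where the curated history pays off: the next-speaker check reasons over the history that will actually be resent to the model, so invalid turns recorded in the comprehensive history cannot skew it. A hypothetical call site; the `next_speaker` field name is assumed from the prompt's wording, since the `NextSpeakerResponse` body is elided in this diff:

```ts
import { GeminiChat } from '../core/geminiChat.js';
import { GeminiClient } from '../core/client.js';
import { checkNextSpeaker } from './nextSpeakerChecker.js';

// Decide, after a model turn completes, whether the caller should request
// another model turn (e.g., the model announced a tool call) or wait for
// user input. Returns false when the check is inconclusive (null result).
async function shouldModelContinue(
  chat: GeminiChat,
  client: GeminiClient,
): Promise<boolean> {
  const result = await checkNextSpeaker(chat, client);
  return result?.next_speaker === 'model';
}
```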
