diff options
| author | Taylor Mullen <[email protected]> | 2025-05-26 14:17:56 -0700 |
|---|---|---|
| committer | N. Taylor Mullen <[email protected]> | 2025-05-26 14:20:28 -0700 |
| commit | 480549e02ed4ae01c7df2abbd98bb0eb5b23bdd5 (patch) | |
| tree | 9eae14b1d3e498a78b5809154189bfcc93de6274 /packages/server/src/core/geminiChat.ts | |
| parent | 02503a3248d377e13dda5cf8ad5fc39cce365811 (diff) | |
Refactor(chat): Introduce custom Chat class for future modifications
- Copied the `Chat` class from `@google/genai` into `packages/server/src/core/geminiChat.ts`.
- This change is in preparation for future modifications to the chat handling logic.
- Updated relevant files to use the new `GeminiChat` class.
Part of https://github.com/google-gemini/gemini-cli/issues/551
Diffstat (limited to 'packages/server/src/core/geminiChat.ts')
| -rw-r--r-- | packages/server/src/core/geminiChat.ts | 314 |
1 file changed, 314 insertions, 0 deletions
diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts new file mode 100644 index 00000000..dd5f3b7a --- /dev/null +++ b/packages/server/src/core/geminiChat.ts @@ -0,0 +1,314 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug +// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090 + +import { + GenerateContentResponse, + Content, + Models, + GenerateContentConfig, + SendMessageParameters, + GoogleGenAI, + createUserContent, +} from '@google/genai'; + +/** + * Returns true if the response is valid, false otherwise. + */ +function isValidResponse(response: GenerateContentResponse): boolean { + if (response.candidates === undefined || response.candidates.length === 0) { + return false; + } + const content = response.candidates[0]?.content; + if (content === undefined) { + return false; + } + return isValidContent(content); +} + +function isValidContent(content: Content): boolean { + if (content.parts === undefined || content.parts.length === 0) { + return false; + } + for (const part of content.parts) { + if (part === undefined || Object.keys(part).length === 0) { + return false; + } + if (!part.thought && part.text !== undefined && part.text === '') { + return false; + } + } + return true; +} + +/** + * Validates the history contains the correct roles. + * + * @throws Error if the history does not start with a user turn. + * @throws Error if the history contains an invalid role. + */ +function validateHistory(history: Content[]) { + // Empty history is valid. 
+ if (history.length === 0) { + return; + } + for (const content of history) { + if (content.role !== 'user' && content.role !== 'model') { + throw new Error(`Role must be user or model, but got ${content.role}.`); + } + } +} + +/** + * Extracts the curated (valid) history from a comprehensive history. + * + * @remarks + * The model may sometimes generate invalid or empty contents(e.g., due to safty + * filters or recitation). Extracting valid turns from the history + * ensures that subsequent requests could be accpeted by the model. + */ +function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] { + if (comprehensiveHistory === undefined || comprehensiveHistory.length === 0) { + return []; + } + const curatedHistory: Content[] = []; + const length = comprehensiveHistory.length; + let i = 0; + while (i < length) { + if (comprehensiveHistory[i].role === 'user') { + curatedHistory.push(comprehensiveHistory[i]); + i++; + } else { + const modelOutput: Content[] = []; + let isValid = true; + while (i < length && comprehensiveHistory[i].role === 'model') { + modelOutput.push(comprehensiveHistory[i]); + if (isValid && !isValidContent(comprehensiveHistory[i])) { + isValid = false; + } + i++; + } + if (isValid) { + curatedHistory.push(...modelOutput); + } else { + // Remove the last user input when model content is invalid. + curatedHistory.pop(); + } + } + } + return curatedHistory; +} + +/** + * Chat session that enables sending messages to the model with previous + * conversation context. + * + * @remarks + * The session maintains all the turns between user and model. + */ +export class GeminiChat { + // A promise to represent the current state of the message being sent to the + // model. 
+ private sendPromise: Promise<void> = Promise.resolve(); + + constructor( + private readonly apiClient: GoogleGenAI, + private readonly modelsModule: Models, + private readonly model: string, + private readonly config: GenerateContentConfig = {}, + private history: Content[] = [], + ) { + validateHistory(history); + } + + /** + * Sends a message to the model and returns the response. + * + * @remarks + * This method will wait for the previous message to be processed before + * sending the next message. + * + * @see {@link Chat#sendMessageStream} for streaming method. + * @param params - parameters for sending messages within a chat session. + * @returns The model's response. + * + * @example + * ```ts + * const chat = ai.chats.create({model: 'gemini-2.0-flash'}); + * const response = await chat.sendMessage({ + * message: 'Why is the sky blue?' + * }); + * console.log(response.text); + * ``` + */ + async sendMessage( + params: SendMessageParameters, + ): Promise<GenerateContentResponse> { + await this.sendPromise; + const userContent = createUserContent(params.message); + const responsePromise = this.modelsModule.generateContent({ + model: this.model, + contents: this.getHistory(true).concat(userContent), + config: params.config ?? this.config, + }); + this.sendPromise = (async () => { + const response = await responsePromise; + const outputContent = response.candidates?.[0]?.content; + + // Because the AFC input contains the entire curated chat history in + // addition to the new user input, we need to truncate the AFC history + // to deduplicate the existing chat history. + const fullAutomaticFunctionCallingHistory = + response.automaticFunctionCallingHistory; + const index = this.getHistory(true).length; + + let automaticFunctionCallingHistory: Content[] = []; + if (fullAutomaticFunctionCallingHistory != null) { + automaticFunctionCallingHistory = + fullAutomaticFunctionCallingHistory.slice(index) ?? []; + } + + const modelOutput = outputContent ? 
[outputContent] : []; + this.recordHistory( + userContent, + modelOutput, + automaticFunctionCallingHistory, + ); + return; + })(); + await this.sendPromise.catch(() => { + // Resets sendPromise to avoid subsequent calls failing + this.sendPromise = Promise.resolve(); + }); + return responsePromise; + } + + /** + * Sends a message to the model and returns the response in chunks. + * + * @remarks + * This method will wait for the previous message to be processed before + * sending the next message. + * + * @see {@link Chat#sendMessage} for non-streaming method. + * @param params - parameters for sending the message. + * @return The model's response. + * + * @example + * ```ts + * const chat = ai.chats.create({model: 'gemini-2.0-flash'}); + * const response = await chat.sendMessageStream({ + * message: 'Why is the sky blue?' + * }); + * for await (const chunk of response) { + * console.log(chunk.text); + * } + * ``` + */ + async sendMessageStream( + params: SendMessageParameters, + ): Promise<AsyncGenerator<GenerateContentResponse>> { + await this.sendPromise; + const userContent = createUserContent(params.message); + const streamResponse = this.modelsModule.generateContentStream({ + model: this.model, + contents: this.getHistory(true).concat(userContent), + config: params.config ?? this.config, + }); + // Resolve the internal tracking of send completion promise - `sendPromise` + // for both success and failure response. The actual failure is still + // propagated by the `await streamResponse`. + this.sendPromise = streamResponse + .then(() => undefined) + .catch(() => undefined); + const response = await streamResponse; + const result = this.processStreamResponse(response, userContent); + return result; + } + + /** + * Returns the chat history. + * + * @remarks + * The history is a list of contents alternating between user and model. 
+ * + * There are two types of history: + * - The `curated history` contains only the valid turns between user and + * model, which will be included in the subsequent requests sent to the model. + * - The `comprehensive history` contains all turns, including invalid or + * empty model outputs, providing a complete record of the history. + * + * The history is updated after receiving the response from the model, + * for streaming response, it means receiving the last chunk of the response. + * + * The `comprehensive history` is returned by default. To get the `curated + * history`, set the `curated` parameter to `true`. + * + * @param curated - whether to return the curated history or the comprehensive + * history. + * @return History contents alternating between user and model for the entire + * chat session. + */ + getHistory(curated: boolean = false): Content[] { + const history = curated + ? extractCuratedHistory(this.history) + : this.history; + // Deep copy the history to avoid mutating the history outside of the + // chat session. + return structuredClone(history); + } + + private async *processStreamResponse( + streamResponse: AsyncGenerator<GenerateContentResponse>, + inputContent: Content, + ) { + const outputContent: Content[] = []; + for await (const chunk of streamResponse) { + if (isValidResponse(chunk)) { + const content = chunk.candidates?.[0]?.content; + if (content !== undefined) { + outputContent.push(content); + } + } + yield chunk; + } + this.recordHistory(inputContent, outputContent); + } + + private recordHistory( + userInput: Content, + modelOutput: Content[], + automaticFunctionCallingHistory?: Content[], + ) { + let outputContents: Content[] = []; + if ( + modelOutput.length > 0 && + modelOutput.every((content) => content.role !== undefined) + ) { + outputContents = modelOutput; + } else { + // Appends an empty content when model returns empty response, so that the + // history is always alternating between user and model. 
+ outputContents.push({ + role: 'model', + parts: [], + } as Content); + } + if ( + automaticFunctionCallingHistory && + automaticFunctionCallingHistory.length > 0 + ) { + this.history.push( + ...extractCuratedHistory(automaticFunctionCallingHistory!), + ); + } else { + this.history.push(userInput); + } + this.history.push(...outputContents); + } +} |
