diff options
Diffstat (limited to 'packages/cli/src/utils/modelCheck.ts')
| -rw-r--r-- | packages/cli/src/utils/modelCheck.ts | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/packages/cli/src/utils/modelCheck.ts b/packages/cli/src/utils/modelCheck.ts new file mode 100644 index 00000000..1634656e --- /dev/null +++ b/packages/cli/src/utils/modelCheck.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +export interface EffectiveModelCheckResult { + effectiveModel: string; + switched: boolean; + originalModelIfSwitched?: string; +} + +/** + * Checks if the default "pro" model is rate-limited and returns a fallback "flash" + * model if necessary. This function is designed to be silent. + * @param apiKey The API key to use for the check. + * @param currentConfiguredModel The model currently configured in settings. + * @returns An object indicating the model to use, whether a switch occurred, + * and the original model if a switch happened. + */ +export async function getEffectiveModel( + apiKey: string, + currentConfiguredModel: string, +): Promise<EffectiveModelCheckResult> { + if (currentConfiguredModel !== DEFAULT_GEMINI_MODEL) { + // Only check if the user is trying to use the specific pro model we want to fallback from. + return { effectiveModel: currentConfiguredModel, switched: false }; + } + + const modelToTest = DEFAULT_GEMINI_MODEL; + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${modelToTest}:generateContent?key=${apiKey}`; + const body = JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 2000); // 500ms timeout for the request + + try { + const response = await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body, + signal: controller.signal, + }); + + clearTimeout(timeoutId); + + if (response.status === 429) { + return { + effectiveModel: fallbackModel, + switched: true, + originalModelIfSwitched: modelToTest, + }; + } + // For any other case (success, other error codes), we stick to the original model. + return { effectiveModel: currentConfiguredModel, switched: false }; + } catch (_error) { + clearTimeout(timeoutId); + // On timeout or any other fetch error, stick to the original model. + return { effectiveModel: currentConfiguredModel, switched: false }; + } +} |
