diff options
Diffstat (limited to 'packages/cli/src/utils')
| -rw-r--r-- | packages/cli/src/utils/modelCheck.test.ts | 179 | ||||
| -rw-r--r-- | packages/cli/src/utils/modelCheck.ts | 75 |
2 files changed, 254 insertions, 0 deletions
diff --git a/packages/cli/src/utils/modelCheck.test.ts b/packages/cli/src/utils/modelCheck.test.ts new file mode 100644 index 00000000..3b1cded8 --- /dev/null +++ b/packages/cli/src/utils/modelCheck.test.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + getEffectiveModel, + type EffectiveModelCheckResult, +} from './modelCheck.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +// Mock global fetch +global.fetch = vi.fn(); + +// Mock AbortController +const mockAbort = vi.fn(); +global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, // Start with not aborted + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any +})) as any; + +describe('getEffectiveModel', () => { + const apiKey = 'test-api-key'; + + beforeEach(() => { + vi.useFakeTimers(); + vi.clearAllMocks(); + // Reset signal for each test if AbortController mock is more complex + global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + })) as any; + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + describe('when currentConfiguredModel is not DEFAULT_GEMINI_MODEL', () => { + it('should return the currentConfiguredModel and switched: false without fetching', async () => { + const customModel = 'custom-model-name'; + const result = await getEffectiveModel(apiKey, customModel); + expect(result).toEqual({ + effectiveModel: customModel, + switched: false, + }); + expect(fetch).not.toHaveBeenCalled(); + }); + }); + + describe('when currentConfiguredModel is DEFAULT_GEMINI_MODEL', () => { + it('should switch to DEFAULT_GEMINI_FLASH_MODEL if fetch returns 429', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 429, + }); + const result: EffectiveModelCheckResult = await getEffectiveModel( + apiKey, + DEFAULT_GEMINI_MODEL, + ); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_FLASH_MODEL, + switched: true, + originalModelIfSwitched: DEFAULT_GEMINI_MODEL, + }); + expect(fetch).toHaveBeenCalledTimes(1); + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${apiKey}`, + expect.any(Object), + ); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns 200', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns a non-429 error status (e.g., 500)', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 500, + }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch throws a network error', async () => { + (fetch as vi.Mock).mockRejectedValueOnce(new Error('Network error')); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch times out', async () => { + // Simulate AbortController's signal changing and fetch throwing AbortError + const abortControllerInstance = { + signal: { aborted: false }, // mutable signal + abort: vi.fn(() => { + abortControllerInstance.signal.aborted = true; // Use abortControllerInstance + }), + }; + (global.AbortController as vi.Mock).mockImplementationOnce( + () => abortControllerInstance, + ); + + (fetch as vi.Mock).mockImplementationOnce( + async ({ signal }: { signal: AbortSignal }) => { + // Simulate the timeout advancing and abort being called + vi.advanceTimersByTime(2000); + if (signal.aborted) { + throw new DOMException('Aborted', 'AbortError'); + } + // Should not reach here in a timeout scenario + return { ok: true, status: 200 }; + }, + ); + + const resultPromise = getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + // Ensure timers are advanced to trigger the timeout within getEffectiveModel + await vi.advanceTimersToNextTimerAsync(); // Or advanceTimersByTime(2000) if more precise control is needed + + const result = await resultPromise; + + expect(mockAbort).toHaveBeenCalledTimes(0); // setTimeout calls controller.abort(), not our direct mockAbort + expect(abortControllerInstance.abort).toHaveBeenCalledTimes(1); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should correctly pass API key and model in the fetch request', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ ok: true, status: 200 }); + const specificApiKey = 'specific-key-for-this-test'; + await getEffectiveModel(specificApiKey, DEFAULT_GEMINI_MODEL); + + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${specificApiKey}`, + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }), + }), + ); + }); + }); +}); diff --git a/packages/cli/src/utils/modelCheck.ts b/packages/cli/src/utils/modelCheck.ts new file mode 100644 index 00000000..1634656e --- /dev/null +++ b/packages/cli/src/utils/modelCheck.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +export interface EffectiveModelCheckResult { + effectiveModel: string; + switched: boolean; + originalModelIfSwitched?: string; +} + +/** + * Checks if the default "pro" model is rate-limited and returns a fallback "flash" + * model if necessary. This function is designed to be silent. + * @param apiKey The API key to use for the check. + * @param currentConfiguredModel The model currently configured in settings. + * @returns An object indicating the model to use, whether a switch occurred, + * and the original model if a switch happened. + */ +export async function getEffectiveModel( + apiKey: string, + currentConfiguredModel: string, +): Promise<EffectiveModelCheckResult> { + if (currentConfiguredModel !== DEFAULT_GEMINI_MODEL) { + // Only check if the user is trying to use the specific pro model we want to fallback from. + return { effectiveModel: currentConfiguredModel, switched: false }; + } + + const modelToTest = DEFAULT_GEMINI_MODEL; + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${modelToTest}:generateContent?key=${apiKey}`; + const body = JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 2000); // 500ms timeout for the request + + try { + const response = await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body, + signal: controller.signal, + }); + + clearTimeout(timeoutId); + + if (response.status === 429) { + return { + effectiveModel: fallbackModel, + switched: true, + originalModelIfSwitched: modelToTest, + }; + } + // For any other case (success, other error codes), we stick to the original model. + return { effectiveModel: currentConfiguredModel, switched: false }; + } catch (_error) { + clearTimeout(timeoutId); + // On timeout or any other fetch error, stick to the original model. + return { effectiveModel: currentConfiguredModel, switched: false }; + } +} |
