diff options
| author | Allen Hutchison <[email protected]> | 2025-06-02 13:55:54 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-02 13:55:54 -0700 |
| commit | 7f20425c98d5adb5531e6c33ed92975b71b34c90 (patch) | |
| tree | d15a9489520c3ebe9f7b22b998803a0b7fb07a8c /packages/cli/src/utils/modelCheck.test.ts | |
| parent | 59b6267b2f3f5d971c10eeaaf9c0e7e82f10cf02 (diff) | |
feat(cli): add pro model availability check and fallback to flash (#608)
Diffstat (limited to 'packages/cli/src/utils/modelCheck.test.ts')
| -rw-r--r-- | packages/cli/src/utils/modelCheck.test.ts | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/packages/cli/src/utils/modelCheck.test.ts b/packages/cli/src/utils/modelCheck.test.ts new file mode 100644 index 00000000..3b1cded8 --- /dev/null +++ b/packages/cli/src/utils/modelCheck.test.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + getEffectiveModel, + type EffectiveModelCheckResult, +} from './modelCheck.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +// Mock global fetch +global.fetch = vi.fn(); + +// Mock AbortController +const mockAbort = vi.fn(); +global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, // Start with not aborted + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any +})) as any; + +describe('getEffectiveModel', () => { + const apiKey = 'test-api-key'; + + beforeEach(() => { + vi.useFakeTimers(); + vi.clearAllMocks(); + // Reset signal for each test if AbortController mock is more complex + global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + })) as any; + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + describe('when currentConfiguredModel is not DEFAULT_GEMINI_MODEL', () => { + it('should return the currentConfiguredModel and switched: false without fetching', async () => { + const customModel = 'custom-model-name'; + const result = await getEffectiveModel(apiKey, customModel); + expect(result).toEqual({ + effectiveModel: customModel, + switched: false, + }); + expect(fetch).not.toHaveBeenCalled(); + }); + }); + + describe('when currentConfiguredModel is DEFAULT_GEMINI_MODEL', () => { + it('should switch to DEFAULT_GEMINI_FLASH_MODEL if fetch returns 429', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 429, + }); + const result: EffectiveModelCheckResult = await getEffectiveModel( + apiKey, + DEFAULT_GEMINI_MODEL, + ); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_FLASH_MODEL, + switched: true, + originalModelIfSwitched: DEFAULT_GEMINI_MODEL, + }); + expect(fetch).toHaveBeenCalledTimes(1); + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${apiKey}`, + expect.any(Object), + ); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns 200', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns a non-429 error status (e.g., 500)', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 500, + }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch throws a network error', async () => { + (fetch as vi.Mock).mockRejectedValueOnce(new Error('Network error')); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch times out', async () => { + // Simulate AbortController's signal changing and fetch throwing AbortError + const abortControllerInstance = { + signal: { aborted: false }, // mutable signal + abort: vi.fn(() => { + abortControllerInstance.signal.aborted = true; // Use abortControllerInstance + }), + }; + (global.AbortController as vi.Mock).mockImplementationOnce( + () => abortControllerInstance, + ); + + (fetch as vi.Mock).mockImplementationOnce( + async ({ signal }: { signal: AbortSignal }) => { + // Simulate the timeout advancing and abort being called + vi.advanceTimersByTime(2000); + if (signal.aborted) { + throw new DOMException('Aborted', 'AbortError'); + } + // Should not reach here in a timeout scenario + return { ok: true, status: 200 }; + }, + ); + + const resultPromise = getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + // Ensure timers are advanced to trigger the timeout within getEffectiveModel + await vi.advanceTimersToNextTimerAsync(); // Or advanceTimersByTime(2000) if more precise control is needed + + const result = await resultPromise; + + expect(mockAbort).toHaveBeenCalledTimes(0); // setTimeout calls controller.abort(), not our direct mockAbort + expect(abortControllerInstance.abort).toHaveBeenCalledTimes(1); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should correctly pass API key and model in the fetch request', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ ok: true, status: 200 }); + const specificApiKey = 'specific-key-for-this-test'; + await getEffectiveModel(specificApiKey, DEFAULT_GEMINI_MODEL); + + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${specificApiKey}`, + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }), + }), + ); + }); + }); +}); |
