diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/config/config.ts | 59 | ||||
| -rw-r--r-- | packages/core/src/config/flashFallback.test.ts | 139 | ||||
| -rw-r--r-- | packages/core/src/core/client.test.ts | 4 | ||||
| -rw-r--r-- | packages/core/src/core/client.ts | 52 | ||||
| -rw-r--r-- | packages/core/src/core/contentGenerator.ts | 6 | ||||
| -rw-r--r-- | packages/core/src/core/geminiChat.test.ts | 9 | ||||
| -rw-r--r-- | packages/core/src/core/geminiChat.ts | 44 | ||||
| -rw-r--r-- | packages/core/src/utils/flashFallback.integration.test.ts | 144 | ||||
| -rw-r--r-- | packages/core/src/utils/retry.test.ts | 199 | ||||
| -rw-r--r-- | packages/core/src/utils/retry.ts | 50 | ||||
| -rw-r--r-- | packages/core/src/utils/testUtils.ts | 87 |
11 files changed, 784 insertions, 9 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index a92dd7ba..b266512c 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -35,7 +35,10 @@ import { TelemetryTarget, StartSessionEvent, } from '../telemetry/index.js'; -import { DEFAULT_GEMINI_EMBEDDING_MODEL } from './models.js'; +import { + DEFAULT_GEMINI_EMBEDDING_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from './models.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; export enum ApprovalMode { @@ -85,6 +88,11 @@ export interface SandboxConfig { image: string; } +export type FlashFallbackHandler = ( + currentModel: string, + fallbackModel: string, +) => Promise<boolean>; + export interface ConfigParameters { sessionId: string; embeddingModel?: string; @@ -156,6 +164,8 @@ export class Config { private readonly bugCommand: BugCommandSettings | undefined; private readonly model: string; private readonly extensionContextFilePaths: string[]; + private modelSwitchedDuringSession: boolean = false; + flashFallbackHandler?: FlashFallbackHandler; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -216,9 +226,24 @@ export class Config { } async refreshAuth(authMethod: AuthType) { + // Check if this is actually a switch to a different auth method + const previousAuthType = this.contentGeneratorConfig?.authType; + const _isAuthMethodSwitch = + previousAuthType && previousAuthType !== authMethod; + + // Always use the original default model when switching auth methods + // This ensures users don't stay on Flash after switching between auth types + // and allows API key users to get proper fallback behavior from getEffectiveModel + const modelToUse = this.model; // Use the original default model + + // Temporarily clear contentGeneratorConfig to prevent getModel() from returning + // the previous session's model (which might be Flash) + this.contentGeneratorConfig = undefined!; + const contentConfig = await createContentGeneratorConfig( - this.getModel(), + modelToUse, authMethod, + this, ); const gc = new GeminiClient(this); @@ -226,6 +251,11 @@ export class Config { this.toolRegistry = await createToolRegistry(this); await gc.initialize(contentConfig); this.contentGeneratorConfig = contentConfig; + + // Reset the session flag since we're explicitly changing auth and using default model + this.modelSwitchedDuringSession = false; + + // Note: In the future, we may want to reset any cached state when switching auth methods } getSessionId(): string { @@ -240,6 +270,28 @@ export class Config { return this.contentGeneratorConfig?.model || this.model; } + setModel(newModel: string): void { + if (this.contentGeneratorConfig) { + this.contentGeneratorConfig.model = newModel; + this.modelSwitchedDuringSession = true; + } + } + + isModelSwitchedDuringSession(): boolean { + return this.modelSwitchedDuringSession; + } + + resetModelToDefault(): void { + if (this.contentGeneratorConfig) { + this.contentGeneratorConfig.model = this.model; // Reset to the original default model + this.modelSwitchedDuringSession = false; + } + } + + setFlashFallbackHandler(handler: FlashFallbackHandler): void { + this.flashFallbackHandler = handler; + } + getEmbeddingModel(): string { return this.embeddingModel; } @@ -445,3 +497,6 @@ export function createToolRegistry(config: Config): Promise<ToolRegistry> { return registry; })(); } + +// Export model constants for use in CLI +export { DEFAULT_GEMINI_FLASH_MODEL }; diff --git a/packages/core/src/config/flashFallback.test.ts b/packages/core/src/config/flashFallback.test.ts new file mode 100644 index 00000000..325cc064 --- /dev/null +++ b/packages/core/src/config/flashFallback.test.ts @@ -0,0 +1,139 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach } from 'vitest'; +import { Config } from './config.js'; +import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL } from './models.js'; + +describe('Flash Model Fallback Configuration', () => { + let config: Config; + + beforeEach(() => { + config = new Config({ + sessionId: 'test-session', + targetDir: '/test', + debugMode: false, + cwd: '/test', + model: DEFAULT_GEMINI_MODEL, + }); + + // Initialize contentGeneratorConfig for testing + ( + config as unknown as { contentGeneratorConfig: unknown } + ).contentGeneratorConfig = { + model: DEFAULT_GEMINI_MODEL, + authType: 'oauth-personal', + }; + }); + + describe('setModel', () => { + it('should update the model and mark as switched during session', () => { + expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(false); + + config.setModel(DEFAULT_GEMINI_FLASH_MODEL); + + expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(true); + }); + + it('should handle multiple model switches during session', () => { + config.setModel(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(true); + + config.setModel('gemini-1.5-pro'); + expect(config.getModel()).toBe('gemini-1.5-pro'); + expect(config.isModelSwitchedDuringSession()).toBe(true); + }); + + it('should only mark as switched if contentGeneratorConfig exists', () => { + // Create config without initializing contentGeneratorConfig + const newConfig = new Config({ + sessionId: 'test-session-2', + targetDir: '/test', + debugMode: false, + cwd: '/test', + model: DEFAULT_GEMINI_MODEL, + }); + + // Should not crash when contentGeneratorConfig is undefined + newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL); + expect(newConfig.isModelSwitchedDuringSession()).toBe(false); + }); + }); + + describe('getModel', () => { + it('should return contentGeneratorConfig model if available', () => { + // Simulate initialized content generator config + config.setModel(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should fallback to initial model if contentGeneratorConfig is not available', () => { + // Test with fresh config where contentGeneratorConfig might not be set + const newConfig = new Config({ + sessionId: 'test-session-2', + targetDir: '/test', + debugMode: false, + cwd: '/test', + model: 'custom-model', + }); + + expect(newConfig.getModel()).toBe('custom-model'); + }); + }); + + describe('isModelSwitchedDuringSession', () => { + it('should start as false for new session', () => { + expect(config.isModelSwitchedDuringSession()).toBe(false); + }); + + it('should remain false if no model switch occurs', () => { + // Perform other operations that don't involve model switching + expect(config.isModelSwitchedDuringSession()).toBe(false); + }); + + it('should persist switched state throughout session', () => { + config.setModel(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(true); + + // Should remain true even after getting model + config.getModel(); + expect(config.isModelSwitchedDuringSession()).toBe(true); + }); + }); + + describe('resetModelToDefault', () => { + it('should reset model to default and clear session switch flag', () => { + // Switch to Flash first + config.setModel(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(true); + + // Reset to default + config.resetModelToDefault(); + + // Should be back to default with flag cleared + expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL); + expect(config.isModelSwitchedDuringSession()).toBe(false); + }); + + it('should handle case where contentGeneratorConfig is not initialized', () => { + // Create config without initializing contentGeneratorConfig + const newConfig = new Config({ + sessionId: 'test-session-2', + targetDir: '/test', + debugMode: false, + cwd: '/test', + model: DEFAULT_GEMINI_MODEL, + }); + + // Should not crash when contentGeneratorConfig is undefined + expect(() => newConfig.resetModelToDefault()).not.toThrow(); + expect(newConfig.isModelSwitchedDuringSession()).toBe(false); + }); + }); +}); diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 8e9bae18..924b4097 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -20,6 +20,7 @@ import { Turn } from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; +import { setSimulate429 } from '../utils/testUtils.js'; // --- Mocks --- const mockChatCreateFn = vi.fn(); @@ -68,6 +69,9 @@ describe('Gemini Client (client.ts)', () => { beforeEach(async () => { vi.resetAllMocks(); + // Disable 429 simulation for tests + setSimulate429(false); + // Set up the mock for GoogleGenAI constructor and its methods const MockedGoogleGenAI = vi.mocked(GoogleGenAI); MockedGoogleGenAI.mockImplementation(() => { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 36acb3e8..eb94baed 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -38,6 +38,7 @@ import { } from './contentGenerator.js'; import { ProxyAgent, setGlobalDispatcher } from 'undici'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { AuthType } from './contentGenerator.js'; function isThinkingSupported(model: string) { if (model.startsWith('gemini-2.5')) return true; @@ -276,7 +277,11 @@ export class GeminiClient { contents, }); - const result = await retryWithBackoff(apiCall); + const result = await retryWithBackoff(apiCall, { + onPersistent429: async (authType?: string) => + await this.handleFlashFallback(authType), + authType: this.config.getContentGeneratorConfig()?.authType, + }); const text = getResponseText(result); if (!text) { @@ -360,7 +365,11 @@ export class GeminiClient { contents, }); - const result = await retryWithBackoff(apiCall); + const result = await retryWithBackoff(apiCall, { + onPersistent429: async (authType?: string) => + await this.handleFlashFallback(authType), + authType: this.config.getContentGeneratorConfig()?.authType, + }); return result; } catch (error: unknown) { if (abortSignal.aborted) { @@ -489,4 +498,43 @@ export class GeminiClient { } : null; } + + /** + * Handles fallback to Flash model when persistent 429 errors occur for OAuth users. + * Uses a fallback handler if provided by the config, otherwise returns null. + */ + private async handleFlashFallback(authType?: string): Promise<string | null> { + // Only handle fallback for OAuth users + if ( + authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL && + authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE + ) { + return null; + } + + const currentModel = this.model; + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + + // Don't fallback if already using Flash model + if (currentModel === fallbackModel) { + return null; + } + + // Check if config has a fallback handler (set by CLI package) + const fallbackHandler = this.config.flashFallbackHandler; + if (typeof fallbackHandler === 'function') { + try { + const accepted = await fallbackHandler(currentModel, fallbackModel); + if (accepted) { + this.config.setModel(fallbackModel); + this.model = fallbackModel; + return fallbackModel; + } + } catch (error) { + console.warn('Flash fallback handler failed:', error); + } + } + + return null; + } } diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 058e69a6..c708dad4 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -51,14 +51,18 @@ export type ContentGeneratorConfig = { export async function createContentGeneratorConfig( model: string | undefined, authType: AuthType | undefined, + config?: { getModel?: () => string }, ): Promise<ContentGeneratorConfig> { const geminiApiKey = process.env.GEMINI_API_KEY; const googleApiKey = process.env.GOOGLE_API_KEY; const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT; const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION; + // Use runtime model from config if available, otherwise fallback to parameter or default + const effectiveModel = config?.getModel?.() || model || DEFAULT_GEMINI_MODEL; + const contentGeneratorConfig: ContentGeneratorConfig = { - model: model || DEFAULT_GEMINI_MODEL, + model: effectiveModel, authType, }; diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 0b1ed339..18b3729e 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -14,6 +14,7 @@ import { } from '@google/genai'; import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; +import { setSimulate429 } from '../utils/testUtils.js'; // Mocks const mockModelsModule = { @@ -29,6 +30,12 @@ const mockConfig = { getTelemetryLogPromptsEnabled: () => true, getUsageStatisticsEnabled: () => true, getDebugMode: () => false, + getContentGeneratorConfig: () => ({ + authType: 'oauth-personal', + model: 'test-model', + }), + setModel: vi.fn(), + flashFallbackHandler: undefined, } as unknown as Config; describe('GeminiChat', () => { @@ -38,6 +45,8 @@ describe('GeminiChat', () => { beforeEach(() => { vi.clearAllMocks(); + // Disable 429 simulation for tests + setSimulate429(false); // Reset history for each test by creating a new instance chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index e08aaf86..4db13852 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -18,7 +18,7 @@ import { } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; -import { ContentGenerator } from './contentGenerator.js'; +import { ContentGenerator, AuthType } from './contentGenerator.js'; import { Config } from '../config/config.js'; import { logApiRequest, @@ -34,6 +34,7 @@ import { ApiRequestEvent, ApiResponseEvent, } from '../telemetry/types.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; /** * Returns true if the response is valid, false otherwise. @@ -182,6 +183,44 @@ export class GeminiChat { } /** + * Handles fallback to Flash model when persistent 429 errors occur for OAuth users. + * Uses a fallback handler if provided by the config, otherwise returns null. + */ + private async handleFlashFallback(authType?: string): Promise<string | null> { + // Only handle fallback for OAuth users + if ( + authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL && + authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE + ) { + return null; + } + + const currentModel = this.model; + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + + // Don't fallback if already using Flash model + if (currentModel === fallbackModel) { + return null; + } + + // Check if config has a fallback handler (set by CLI package) + const fallbackHandler = this.config.flashFallbackHandler; + if (typeof fallbackHandler === 'function') { + try { + const accepted = await fallbackHandler(currentModel, fallbackModel); + if (accepted) { + this.config.setModel(fallbackModel); + return fallbackModel; + } + } catch (error) { + console.warn('Flash fallback handler failed:', error); + } + } + + return null; + } + + /** * Sends a message to the model and returns the response. * * @remarks @@ -315,6 +354,9 @@ export class GeminiChat { } return false; // Don't retry other errors by default }, + onPersistent429: async (authType?: string) => + await this.handleFlashFallback(authType), + authType: this.config.getContentGeneratorConfig()?.authType, }); // Resolve the internal tracking of send completion promise - `sendPromise` diff --git a/packages/core/src/utils/flashFallback.integration.test.ts b/packages/core/src/utils/flashFallback.integration.test.ts new file mode 100644 index 00000000..21c40296 --- /dev/null +++ b/packages/core/src/utils/flashFallback.integration.test.ts @@ -0,0 +1,144 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { Config } from '../config/config.js'; +import { + setSimulate429, + disableSimulationAfterFallback, + shouldSimulate429, + createSimulated429Error, + resetRequestCounter, +} from './testUtils.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { retryWithBackoff } from './retry.js'; +import { AuthType } from '../core/contentGenerator.js'; + +describe('Flash Fallback Integration', () => { + let config: Config; + + beforeEach(() => { + config = new Config({ + sessionId: 'test-session', + targetDir: '/test', + debugMode: false, + cwd: '/test', + model: 'gemini-2.5-pro', + }); + + // Reset simulation state for each test + setSimulate429(false); + resetRequestCounter(); + }); + + it('should automatically accept fallback', async () => { + // Set up a minimal flash fallback handler for testing + const flashFallbackHandler = async (): Promise<boolean> => true; + + config.setFlashFallbackHandler(flashFallbackHandler); + + // Call the handler directly to test + const result = await config.flashFallbackHandler!( + 'gemini-2.5-pro', + DEFAULT_GEMINI_FLASH_MODEL, + ); + + // Verify it automatically accepts + expect(result).toBe(true); + }); + + it('should trigger fallback after 3 consecutive 429 errors for OAuth users', async () => { + let fallbackCalled = false; + let fallbackModel = ''; + + // Mock function that simulates exactly 3 429 errors, then succeeds after fallback + const mockApiCall = vi + .fn() + .mockRejectedValueOnce(createSimulated429Error()) + .mockRejectedValueOnce(createSimulated429Error()) + .mockRejectedValueOnce(createSimulated429Error()) + .mockResolvedValueOnce('success after fallback'); + + // Mock fallback handler + const mockFallbackHandler = vi.fn(async (_authType?: string) => { + fallbackCalled = true; + fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + return fallbackModel; + }); + + // Test with OAuth personal auth type, with maxAttempts = 3 to ensure fallback triggers + const result = await retryWithBackoff(mockApiCall, { + maxAttempts: 3, + initialDelayMs: 1, + maxDelayMs: 10, + shouldRetry: (error: Error) => { + const status = (error as Error & { status?: number }).status; + return status === 429; + }, + onPersistent429: mockFallbackHandler, + authType: AuthType.LOGIN_WITH_GOOGLE_PERSONAL, + }); + + // Verify fallback was triggered + expect(fallbackCalled).toBe(true); + expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(mockFallbackHandler).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE_PERSONAL, + ); + expect(result).toBe('success after fallback'); + // Should have: 3 failures, then fallback triggered, then 1 success after retry reset + expect(mockApiCall).toHaveBeenCalledTimes(4); + }); + + it('should not trigger fallback for API key users', async () => { + let fallbackCalled = false; + + // Mock function that simulates 429 errors + const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error()); + + // Mock fallback handler + const mockFallbackHandler = vi.fn(async () => { + fallbackCalled = true; + return DEFAULT_GEMINI_FLASH_MODEL; + }); + + // Test with API key auth type - should not trigger fallback + try { + await retryWithBackoff(mockApiCall, { + maxAttempts: 5, + initialDelayMs: 10, + maxDelayMs: 100, + shouldRetry: (error: Error) => { + const status = (error as Error & { status?: number }).status; + return status === 429; + }, + onPersistent429: mockFallbackHandler, + authType: AuthType.USE_GEMINI, // API key auth type + }); + } catch (error) { + // Expected to throw after max attempts + expect((error as Error).message).toContain('Rate limit exceeded'); + } + + // Verify fallback was NOT triggered for API key users + expect(fallbackCalled).toBe(false); + expect(mockFallbackHandler).not.toHaveBeenCalled(); + }); + + it('should properly disable simulation state after fallback', () => { + // Enable simulation + setSimulate429(true); + + // Verify simulation is enabled + expect(shouldSimulate429()).toBe(true); + + // Disable simulation after fallback + disableSimulationAfterFallback(); + + // Verify simulation is now disabled + expect(shouldSimulate429()).toBe(false); + }); +}); diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 4c269987..39f62981 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -7,6 +7,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { retryWithBackoff } from './retry.js'; +import { setSimulate429 } from './testUtils.js'; // Define an interface for the error with a status property interface HttpError extends Error { @@ -42,10 +43,15 @@ class NonRetryableError extends Error { describe('retryWithBackoff', () => { beforeEach(() => { vi.useFakeTimers(); + // Disable 429 simulation for tests + setSimulate429(false); + // Suppress unhandled promise rejection warnings for tests that expect errors + console.warn = vi.fn(); }); afterEach(() => { vi.restoreAllMocks(); + vi.useRealTimers(); }); it('should return the result on the first attempt if successful', async () => { @@ -231,4 +237,197 @@ describe('retryWithBackoff', () => { expect(d).toBeLessThanOrEqual(100 * 1.3); }); }); + + describe('Flash model fallback for OAuth users', () => { + it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => { + const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash'); + + let fallbackOccurred = false; + const mockFn = vi.fn().mockImplementation(async () => { + if (!fallbackOccurred) { + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + } + return 'success'; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 100, + onPersistent429: async (authType?: string) => { + fallbackOccurred = true; + return await fallbackCallback(authType); + }, + authType: 'oauth-personal', + }); + + // Advance all timers to complete retries + await vi.runAllTimersAsync(); + + // Should succeed after fallback + await expect(promise).resolves.toBe('success'); + + // Verify callback was called with correct auth type + expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); + + // Should retry again after fallback + expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback + }); + + it('should trigger fallback for OAuth enterprise users after persistent 429 errors', async () => { + const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash'); + + let fallbackOccurred = false; + const mockFn = vi.fn().mockImplementation(async () => { + if (!fallbackOccurred) { + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + } + return 'success'; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 100, + onPersistent429: async (authType?: string) => { + fallbackOccurred = true; + return await fallbackCallback(authType); + }, + authType: 'oauth-enterprise', + }); + + await vi.runAllTimersAsync(); + + await expect(promise).resolves.toBe('success'); + expect(fallbackCallback).toHaveBeenCalledWith('oauth-enterprise'); + }); + + it('should NOT trigger fallback for API key users', async () => { + const fallbackCallback = vi.fn(); + + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 100, + onPersistent429: fallbackCallback, + authType: 'gemini-api-key', + }); + + // Handle the promise properly to avoid unhandled rejections + const resultPromise = promise.catch((error) => error); + await vi.runAllTimersAsync(); + const result = await resultPromise; + + // Should fail after all retries without fallback + expect(result).toBeInstanceOf(Error); + expect(result.message).toBe('Rate limit exceeded'); + + // Callback should not be called for API key users + expect(fallbackCallback).not.toHaveBeenCalled(); + }); + + it('should reset attempt counter and continue after successful fallback', async () => { + let fallbackCalled = false; + const fallbackCallback = vi.fn().mockImplementation(async () => { + fallbackCalled = true; + return 'gemini-2.5-flash'; + }); + + const mockFn = vi.fn().mockImplementation(async () => { + if (!fallbackCalled) { + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + } + return 'success'; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 100, + onPersistent429: fallbackCallback, + authType: 'oauth-personal', + }); + + await vi.runAllTimersAsync(); + + await expect(promise).resolves.toBe('success'); + expect(fallbackCallback).toHaveBeenCalledOnce(); + }); + + it('should continue with original error if fallback is rejected', async () => { + const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback + + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 100, + onPersistent429: fallbackCallback, + authType: 'oauth-personal', + }); + + // Handle the promise properly to avoid unhandled rejections + const resultPromise = promise.catch((error) => error); + await vi.runAllTimersAsync(); + const result = await resultPromise; + + // Should fail with original error when fallback is rejected + expect(result).toBeInstanceOf(Error); + expect(result.message).toBe('Rate limit exceeded'); + expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); + }); + + it('should handle mixed error types (only count consecutive 429s)', async () => { + const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash'); + let attempts = 0; + let fallbackOccurred = false; + + const mockFn = vi.fn().mockImplementation(async () => { + attempts++; + if (fallbackOccurred) { + return 'success'; + } + if (attempts === 1) { + // First attempt: 500 error (resets consecutive count) + const error: HttpError = new Error('Server error'); + error.status = 500; + throw error; + } else { + // Remaining attempts: 429 errors + const error: HttpError = new Error('Rate limit exceeded'); + error.status = 429; + throw error; + } + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 5, + initialDelayMs: 100, + onPersistent429: async (authType?: string) => { + fallbackOccurred = true; + return await fallbackCallback(authType); + }, + authType: 'oauth-personal', + }); + + await vi.runAllTimersAsync(); + + await expect(promise).resolves.toBe('success'); + + // Should trigger fallback after 4 consecutive 429s (attempts 2-5) + expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); + }); + }); }); diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 1e7d5bcb..e0fc4ced 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -4,11 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { AuthType } from '../core/contentGenerator.js'; + export interface RetryOptions { maxAttempts: number; initialDelayMs: number; maxDelayMs: number; shouldRetry: (error: Error) => boolean; + onPersistent429?: (authType?: string) => Promise<string | null>; + authType?: string; } const DEFAULT_RETRY_OPTIONS: RetryOptions = { @@ -59,29 +63,69 @@ export async function retryWithBackoff<T>( fn: () => Promise<T>, options?: Partial<RetryOptions>, ): Promise<T> { - const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = { + const { + maxAttempts, + initialDelayMs, + maxDelayMs, + shouldRetry, + onPersistent429, + authType, + } = { ...DEFAULT_RETRY_OPTIONS, ...options, }; let attempt = 0; let currentDelay = initialDelayMs; + let consecutive429Count = 0; while (attempt < maxAttempts) { attempt++; try { return await fn(); } catch (error) { + const errorStatus = getErrorStatus(error); + + // Track consecutive 429 errors + if (errorStatus === 429) { + consecutive429Count++; + } else { + consecutive429Count = 0; + } + + // Check if we've exhausted retries or shouldn't retry if (attempt >= maxAttempts || !shouldRetry(error as Error)) { + // If we have persistent 429s and a fallback callback for OAuth + if ( + consecutive429Count >= 3 && + onPersistent429 && + (authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL || + authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE) + ) { + try { + const fallbackModel = await onPersistent429(authType); + if (fallbackModel) { + // Reset attempt counter and try with new model + attempt = 0; + consecutive429Count = 0; + currentDelay = initialDelayMs; + continue; + } + } catch (fallbackError) { + // If fallback fails, continue with original error + console.warn('Fallback to Flash model failed:', fallbackError); + } + } throw error; } - const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error); + const { delayDurationMs, errorStatus: delayErrorStatus } = + getDelayDurationAndStatus(error); if (delayDurationMs > 0) { // Respect Retry-After header if present and parsed console.warn( - `Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`, + `Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`, error, ); await delay(delayDurationMs); diff --git a/packages/core/src/utils/testUtils.ts b/packages/core/src/utils/testUtils.ts new file mode 100644 index 00000000..a0010b10 --- /dev/null +++ b/packages/core/src/utils/testUtils.ts @@ -0,0 +1,87 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Testing utilities for simulating 429 errors in unit tests + */ + +let requestCounter = 0; +let simulate429Enabled = false; +let simulate429AfterRequests = 0; +let simulate429ForAuthType: string | undefined; +let fallbackOccurred = false; + +/** + * Check if we should simulate a 429 error for the current request + */ +export function shouldSimulate429(authType?: string): boolean { + if (!simulate429Enabled || fallbackOccurred) { + return false; + } + + // If auth type filter is set, only simulate for that auth type + if (simulate429ForAuthType && authType !== simulate429ForAuthType) { + return false; + } + + requestCounter++; + + // If afterRequests is set, only simulate after that many requests + if (simulate429AfterRequests > 0) { + return requestCounter > simulate429AfterRequests; + } + + // Otherwise, simulate for every request + return true; +} + +/** + * Reset the request counter (useful for tests) + */ +export function resetRequestCounter(): void { + requestCounter = 0; +} + +/** + * Disable 429 simulation after successful fallback + */ +export function disableSimulationAfterFallback(): void { + fallbackOccurred = true; +} + +/** + * Create a simulated 429 error response + */ +export function createSimulated429Error(): Error { + const error = new Error('Rate limit exceeded (simulated)') as Error & { + status: number; + }; + error.status = 429; + return error; +} + +/** + * Reset simulation state when switching auth methods + */ +export function resetSimulationState(): void { + fallbackOccurred = false; + resetRequestCounter(); +} + +/** + * Enable/disable 429 simulation programmatically (for tests) + */ +export function setSimulate429( + enabled: boolean, + afterRequests = 0, + forAuthType?: string, +): void { + simulate429Enabled = enabled; + simulate429AfterRequests = afterRequests; + simulate429ForAuthType = forAuthType; + fallbackOccurred = false; // Reset fallback state when simulation is re-enabled + resetRequestCounter(); +} |
