summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorBryan Morgan <[email protected]>2025-06-24 18:48:55 -0400
committerGitHub <[email protected]>2025-06-24 22:48:55 +0000
commite356949d3fb600abd1a993949300a6c3e0008621 (patch)
treed6c32b08bc47e2f3c2d8f6f27e890c1af3ade480 /packages/core/src
parent4bf18da2b08e145d2f4c91f2331347bf8568aed3 (diff)
[JUNE 25] Permanent failover to Flash model for OAuth users after persistent 429 errors (#1376)
Co-authored-by: Scott Densmore <[email protected]>
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/config/config.ts59
-rw-r--r--packages/core/src/config/flashFallback.test.ts139
-rw-r--r--packages/core/src/core/client.test.ts4
-rw-r--r--packages/core/src/core/client.ts52
-rw-r--r--packages/core/src/core/contentGenerator.ts6
-rw-r--r--packages/core/src/core/geminiChat.test.ts9
-rw-r--r--packages/core/src/core/geminiChat.ts44
-rw-r--r--packages/core/src/utils/flashFallback.integration.test.ts144
-rw-r--r--packages/core/src/utils/retry.test.ts199
-rw-r--r--packages/core/src/utils/retry.ts50
-rw-r--r--packages/core/src/utils/testUtils.ts87
11 files changed, 784 insertions, 9 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index a92dd7ba..b266512c 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -35,7 +35,10 @@ import {
TelemetryTarget,
StartSessionEvent,
} from '../telemetry/index.js';
-import { DEFAULT_GEMINI_EMBEDDING_MODEL } from './models.js';
+import {
+ DEFAULT_GEMINI_EMBEDDING_MODEL,
+ DEFAULT_GEMINI_FLASH_MODEL,
+} from './models.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
export enum ApprovalMode {
@@ -85,6 +88,11 @@ export interface SandboxConfig {
image: string;
}
+export type FlashFallbackHandler = (
+ currentModel: string,
+ fallbackModel: string,
+) => Promise<boolean>;
+
export interface ConfigParameters {
sessionId: string;
embeddingModel?: string;
@@ -156,6 +164,8 @@ export class Config {
private readonly bugCommand: BugCommandSettings | undefined;
private readonly model: string;
private readonly extensionContextFilePaths: string[];
+ private modelSwitchedDuringSession: boolean = false;
+ flashFallbackHandler?: FlashFallbackHandler;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -216,9 +226,24 @@ export class Config {
}
async refreshAuth(authMethod: AuthType) {
+ // Check if this is actually a switch to a different auth method
+ const previousAuthType = this.contentGeneratorConfig?.authType;
+ const _isAuthMethodSwitch =
+ previousAuthType && previousAuthType !== authMethod;
+
+ // Always use the original default model when switching auth methods
+ // This ensures users don't stay on Flash after switching between auth types
+ // and allows API key users to get proper fallback behavior from getEffectiveModel
+ const modelToUse = this.model; // Use the original default model
+
+ // Temporarily clear contentGeneratorConfig to prevent getModel() from returning
+ // the previous session's model (which might be Flash)
+ this.contentGeneratorConfig = undefined!;
+
const contentConfig = await createContentGeneratorConfig(
- this.getModel(),
+ modelToUse,
authMethod,
+ this,
);
const gc = new GeminiClient(this);
@@ -226,6 +251,11 @@ export class Config {
this.toolRegistry = await createToolRegistry(this);
await gc.initialize(contentConfig);
this.contentGeneratorConfig = contentConfig;
+
+ // Reset the session flag since we're explicitly changing auth and using default model
+ this.modelSwitchedDuringSession = false;
+
+ // Note: In the future, we may want to reset any cached state when switching auth methods
}
getSessionId(): string {
@@ -240,6 +270,28 @@ export class Config {
return this.contentGeneratorConfig?.model || this.model;
}
+ setModel(newModel: string): void {
+ if (this.contentGeneratorConfig) {
+ this.contentGeneratorConfig.model = newModel;
+ this.modelSwitchedDuringSession = true;
+ }
+ }
+
+ isModelSwitchedDuringSession(): boolean {
+ return this.modelSwitchedDuringSession;
+ }
+
+ resetModelToDefault(): void {
+ if (this.contentGeneratorConfig) {
+ this.contentGeneratorConfig.model = this.model; // Reset to the original default model
+ this.modelSwitchedDuringSession = false;
+ }
+ }
+
+ setFlashFallbackHandler(handler: FlashFallbackHandler): void {
+ this.flashFallbackHandler = handler;
+ }
+
getEmbeddingModel(): string {
return this.embeddingModel;
}
@@ -445,3 +497,6 @@ export function createToolRegistry(config: Config): Promise<ToolRegistry> {
return registry;
})();
}
+
+// Export model constants for use in CLI
+export { DEFAULT_GEMINI_FLASH_MODEL };
diff --git a/packages/core/src/config/flashFallback.test.ts b/packages/core/src/config/flashFallback.test.ts
new file mode 100644
index 00000000..325cc064
--- /dev/null
+++ b/packages/core/src/config/flashFallback.test.ts
@@ -0,0 +1,139 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, beforeEach } from 'vitest';
+import { Config } from './config.js';
+import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL } from './models.js';
+
+describe('Flash Model Fallback Configuration', () => {
+ let config: Config;
+
+ beforeEach(() => {
+ config = new Config({
+ sessionId: 'test-session',
+ targetDir: '/test',
+ debugMode: false,
+ cwd: '/test',
+ model: DEFAULT_GEMINI_MODEL,
+ });
+
+ // Initialize contentGeneratorConfig for testing
+ (
+ config as unknown as { contentGeneratorConfig: unknown }
+ ).contentGeneratorConfig = {
+ model: DEFAULT_GEMINI_MODEL,
+ authType: 'oauth-personal',
+ };
+ });
+
+ describe('setModel', () => {
+ it('should update the model and mark as switched during session', () => {
+ expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(false);
+
+ config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+
+ expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+ });
+
+ it('should handle multiple model switches during session', () => {
+ config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+
+ config.setModel('gemini-1.5-pro');
+ expect(config.getModel()).toBe('gemini-1.5-pro');
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+ });
+
+ it('should only mark as switched if contentGeneratorConfig exists', () => {
+ // Create config without initializing contentGeneratorConfig
+ const newConfig = new Config({
+ sessionId: 'test-session-2',
+ targetDir: '/test',
+ debugMode: false,
+ cwd: '/test',
+ model: DEFAULT_GEMINI_MODEL,
+ });
+
+ // Should not crash when contentGeneratorConfig is undefined
+ newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(newConfig.isModelSwitchedDuringSession()).toBe(false);
+ });
+ });
+
+ describe('getModel', () => {
+ it('should return contentGeneratorConfig model if available', () => {
+ // Simulate initialized content generator config
+ config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
+ });
+
+ it('should fallback to initial model if contentGeneratorConfig is not available', () => {
+ // Test with fresh config where contentGeneratorConfig might not be set
+ const newConfig = new Config({
+ sessionId: 'test-session-2',
+ targetDir: '/test',
+ debugMode: false,
+ cwd: '/test',
+ model: 'custom-model',
+ });
+
+ expect(newConfig.getModel()).toBe('custom-model');
+ });
+ });
+
+ describe('isModelSwitchedDuringSession', () => {
+ it('should start as false for new session', () => {
+ expect(config.isModelSwitchedDuringSession()).toBe(false);
+ });
+
+ it('should remain false if no model switch occurs', () => {
+ // Perform other operations that don't involve model switching
+ expect(config.isModelSwitchedDuringSession()).toBe(false);
+ });
+
+ it('should persist switched state throughout session', () => {
+ config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+
+ // Should remain true even after getting model
+ config.getModel();
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+ });
+ });
+
+ describe('resetModelToDefault', () => {
+ it('should reset model to default and clear session switch flag', () => {
+ // Switch to Flash first
+ config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(true);
+
+ // Reset to default
+ config.resetModelToDefault();
+
+ // Should be back to default with flag cleared
+ expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
+ expect(config.isModelSwitchedDuringSession()).toBe(false);
+ });
+
+ it('should handle case where contentGeneratorConfig is not initialized', () => {
+ // Create config without initializing contentGeneratorConfig
+ const newConfig = new Config({
+ sessionId: 'test-session-2',
+ targetDir: '/test',
+ debugMode: false,
+ cwd: '/test',
+ model: DEFAULT_GEMINI_MODEL,
+ });
+
+ // Should not crash when contentGeneratorConfig is undefined
+ expect(() => newConfig.resetModelToDefault()).not.toThrow();
+ expect(newConfig.isModelSwitchedDuringSession()).toBe(false);
+ });
+ });
+});
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 8e9bae18..924b4097 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -20,6 +20,7 @@ import { Turn } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
+import { setSimulate429 } from '../utils/testUtils.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
@@ -68,6 +69,9 @@ describe('Gemini Client (client.ts)', () => {
beforeEach(async () => {
vi.resetAllMocks();
+ // Disable 429 simulation for tests
+ setSimulate429(false);
+
// Set up the mock for GoogleGenAI constructor and its methods
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
MockedGoogleGenAI.mockImplementation(() => {
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 36acb3e8..eb94baed 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -38,6 +38,7 @@ import {
} from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
+import { AuthType } from './contentGenerator.js';
function isThinkingSupported(model: string) {
if (model.startsWith('gemini-2.5')) return true;
@@ -276,7 +277,11 @@ export class GeminiClient {
contents,
});
- const result = await retryWithBackoff(apiCall);
+ const result = await retryWithBackoff(apiCall, {
+ onPersistent429: async (authType?: string) =>
+ await this.handleFlashFallback(authType),
+ authType: this.config.getContentGeneratorConfig()?.authType,
+ });
const text = getResponseText(result);
if (!text) {
@@ -360,7 +365,11 @@ export class GeminiClient {
contents,
});
- const result = await retryWithBackoff(apiCall);
+ const result = await retryWithBackoff(apiCall, {
+ onPersistent429: async (authType?: string) =>
+ await this.handleFlashFallback(authType),
+ authType: this.config.getContentGeneratorConfig()?.authType,
+ });
return result;
} catch (error: unknown) {
if (abortSignal.aborted) {
@@ -489,4 +498,43 @@ export class GeminiClient {
}
: null;
}
+
+ /**
+ * Handles fallback to Flash model when persistent 429 errors occur for OAuth users.
+ * Uses a fallback handler if provided by the config, otherwise returns null.
+ */
+ private async handleFlashFallback(authType?: string): Promise<string | null> {
+ // Only handle fallback for OAuth users
+ if (
+ authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL &&
+ authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE
+ ) {
+ return null;
+ }
+
+ const currentModel = this.model;
+ const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+
+ // Don't fallback if already using Flash model
+ if (currentModel === fallbackModel) {
+ return null;
+ }
+
+ // Check if config has a fallback handler (set by CLI package)
+ const fallbackHandler = this.config.flashFallbackHandler;
+ if (typeof fallbackHandler === 'function') {
+ try {
+ const accepted = await fallbackHandler(currentModel, fallbackModel);
+ if (accepted) {
+ this.config.setModel(fallbackModel);
+ this.model = fallbackModel;
+ return fallbackModel;
+ }
+ } catch (error) {
+ console.warn('Flash fallback handler failed:', error);
+ }
+ }
+
+ return null;
+ }
}
diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts
index 058e69a6..c708dad4 100644
--- a/packages/core/src/core/contentGenerator.ts
+++ b/packages/core/src/core/contentGenerator.ts
@@ -51,14 +51,18 @@ export type ContentGeneratorConfig = {
export async function createContentGeneratorConfig(
model: string | undefined,
authType: AuthType | undefined,
+ config?: { getModel?: () => string },
): Promise<ContentGeneratorConfig> {
const geminiApiKey = process.env.GEMINI_API_KEY;
const googleApiKey = process.env.GOOGLE_API_KEY;
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT;
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION;
+ // Use runtime model from config if available, otherwise fallback to parameter or default
+ const effectiveModel = config?.getModel?.() || model || DEFAULT_GEMINI_MODEL;
+
const contentGeneratorConfig: ContentGeneratorConfig = {
- model: model || DEFAULT_GEMINI_MODEL,
+ model: effectiveModel,
authType,
};
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index 0b1ed339..18b3729e 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -14,6 +14,7 @@ import {
} from '@google/genai';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
+import { setSimulate429 } from '../utils/testUtils.js';
// Mocks
const mockModelsModule = {
@@ -29,6 +30,12 @@ const mockConfig = {
getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
+ getContentGeneratorConfig: () => ({
+ authType: 'oauth-personal',
+ model: 'test-model',
+ }),
+ setModel: vi.fn(),
+ flashFallbackHandler: undefined,
} as unknown as Config;
describe('GeminiChat', () => {
@@ -38,6 +45,8 @@ describe('GeminiChat', () => {
beforeEach(() => {
vi.clearAllMocks();
+ // Disable 429 simulation for tests
+ setSimulate429(false);
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
});
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index e08aaf86..4db13852 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -18,7 +18,7 @@ import {
} from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
-import { ContentGenerator } from './contentGenerator.js';
+import { ContentGenerator, AuthType } from './contentGenerator.js';
import { Config } from '../config/config.js';
import {
logApiRequest,
@@ -34,6 +34,7 @@ import {
ApiRequestEvent,
ApiResponseEvent,
} from '../telemetry/types.js';
+import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
/**
* Returns true if the response is valid, false otherwise.
@@ -182,6 +183,44 @@ export class GeminiChat {
}
/**
+ * Handles fallback to Flash model when persistent 429 errors occur for OAuth users.
+ * Uses a fallback handler if provided by the config, otherwise returns null.
+ */
+ private async handleFlashFallback(authType?: string): Promise<string | null> {
+ // Only handle fallback for OAuth users
+ if (
+ authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL &&
+ authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE
+ ) {
+ return null;
+ }
+
+ const currentModel = this.model;
+ const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+
+ // Don't fallback if already using Flash model
+ if (currentModel === fallbackModel) {
+ return null;
+ }
+
+ // Check if config has a fallback handler (set by CLI package)
+ const fallbackHandler = this.config.flashFallbackHandler;
+ if (typeof fallbackHandler === 'function') {
+ try {
+ const accepted = await fallbackHandler(currentModel, fallbackModel);
+ if (accepted) {
+ this.config.setModel(fallbackModel);
+ return fallbackModel;
+ }
+ } catch (error) {
+ console.warn('Flash fallback handler failed:', error);
+ }
+ }
+
+ return null;
+ }
+
+ /**
* Sends a message to the model and returns the response.
*
* @remarks
@@ -315,6 +354,9 @@ export class GeminiChat {
}
return false; // Don't retry other errors by default
},
+ onPersistent429: async (authType?: string) =>
+ await this.handleFlashFallback(authType),
+ authType: this.config.getContentGeneratorConfig()?.authType,
});
// Resolve the internal tracking of send completion promise - `sendPromise`
diff --git a/packages/core/src/utils/flashFallback.integration.test.ts b/packages/core/src/utils/flashFallback.integration.test.ts
new file mode 100644
index 00000000..21c40296
--- /dev/null
+++ b/packages/core/src/utils/flashFallback.integration.test.ts
@@ -0,0 +1,144 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { Config } from '../config/config.js';
+import {
+ setSimulate429,
+ disableSimulationAfterFallback,
+ shouldSimulate429,
+ createSimulated429Error,
+ resetRequestCounter,
+} from './testUtils.js';
+import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
+import { retryWithBackoff } from './retry.js';
+import { AuthType } from '../core/contentGenerator.js';
+
+describe('Flash Fallback Integration', () => {
+ let config: Config;
+
+ beforeEach(() => {
+ config = new Config({
+ sessionId: 'test-session',
+ targetDir: '/test',
+ debugMode: false,
+ cwd: '/test',
+ model: 'gemini-2.5-pro',
+ });
+
+ // Reset simulation state for each test
+ setSimulate429(false);
+ resetRequestCounter();
+ });
+
+ it('should automatically accept fallback', async () => {
+ // Set up a minimal flash fallback handler for testing
+ const flashFallbackHandler = async (): Promise<boolean> => true;
+
+ config.setFlashFallbackHandler(flashFallbackHandler);
+
+ // Call the handler directly to test
+ const result = await config.flashFallbackHandler!(
+ 'gemini-2.5-pro',
+ DEFAULT_GEMINI_FLASH_MODEL,
+ );
+
+ // Verify it automatically accepts
+ expect(result).toBe(true);
+ });
+
+ it('should trigger fallback after 3 consecutive 429 errors for OAuth users', async () => {
+ let fallbackCalled = false;
+ let fallbackModel = '';
+
+ // Mock function that simulates exactly 3 429 errors, then succeeds after fallback
+ const mockApiCall = vi
+ .fn()
+ .mockRejectedValueOnce(createSimulated429Error())
+ .mockRejectedValueOnce(createSimulated429Error())
+ .mockRejectedValueOnce(createSimulated429Error())
+ .mockResolvedValueOnce('success after fallback');
+
+ // Mock fallback handler
+ const mockFallbackHandler = vi.fn(async (_authType?: string) => {
+ fallbackCalled = true;
+ fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+ return fallbackModel;
+ });
+
+ // Test with OAuth personal auth type, with maxAttempts = 3 to ensure fallback triggers
+ const result = await retryWithBackoff(mockApiCall, {
+ maxAttempts: 3,
+ initialDelayMs: 1,
+ maxDelayMs: 10,
+ shouldRetry: (error: Error) => {
+ const status = (error as Error & { status?: number }).status;
+ return status === 429;
+ },
+ onPersistent429: mockFallbackHandler,
+ authType: AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
+ });
+
+ // Verify fallback was triggered
+ expect(fallbackCalled).toBe(true);
+ expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
+ expect(mockFallbackHandler).toHaveBeenCalledWith(
+ AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
+ );
+ expect(result).toBe('success after fallback');
+ // Should have: 3 failures, then fallback triggered, then 1 success after retry reset
+ expect(mockApiCall).toHaveBeenCalledTimes(4);
+ });
+
+ it('should not trigger fallback for API key users', async () => {
+ let fallbackCalled = false;
+
+ // Mock function that simulates 429 errors
+ const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
+
+ // Mock fallback handler
+ const mockFallbackHandler = vi.fn(async () => {
+ fallbackCalled = true;
+ return DEFAULT_GEMINI_FLASH_MODEL;
+ });
+
+ // Test with API key auth type - should not trigger fallback
+ try {
+ await retryWithBackoff(mockApiCall, {
+ maxAttempts: 5,
+ initialDelayMs: 10,
+ maxDelayMs: 100,
+ shouldRetry: (error: Error) => {
+ const status = (error as Error & { status?: number }).status;
+ return status === 429;
+ },
+ onPersistent429: mockFallbackHandler,
+ authType: AuthType.USE_GEMINI, // API key auth type
+ });
+ } catch (error) {
+ // Expected to throw after max attempts
+ expect((error as Error).message).toContain('Rate limit exceeded');
+ }
+
+ // Verify fallback was NOT triggered for API key users
+ expect(fallbackCalled).toBe(false);
+ expect(mockFallbackHandler).not.toHaveBeenCalled();
+ });
+
+ it('should properly disable simulation state after fallback', () => {
+ // Enable simulation
+ setSimulate429(true);
+
+ // Verify simulation is enabled
+ expect(shouldSimulate429()).toBe(true);
+
+ // Disable simulation after fallback
+ disableSimulationAfterFallback();
+
+ // Verify simulation is now disabled
+ expect(shouldSimulate429()).toBe(false);
+ });
+});
diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts
index 4c269987..39f62981 100644
--- a/packages/core/src/utils/retry.test.ts
+++ b/packages/core/src/utils/retry.test.ts
@@ -7,6 +7,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { retryWithBackoff } from './retry.js';
+import { setSimulate429 } from './testUtils.js';
// Define an interface for the error with a status property
interface HttpError extends Error {
@@ -42,10 +43,15 @@ class NonRetryableError extends Error {
describe('retryWithBackoff', () => {
beforeEach(() => {
vi.useFakeTimers();
+ // Disable 429 simulation for tests
+ setSimulate429(false);
+ // Suppress unhandled promise rejection warnings for tests that expect errors
+ console.warn = vi.fn();
});
afterEach(() => {
vi.restoreAllMocks();
+ vi.useRealTimers();
});
it('should return the result on the first attempt if successful', async () => {
@@ -231,4 +237,197 @@ describe('retryWithBackoff', () => {
expect(d).toBeLessThanOrEqual(100 * 1.3);
});
});
+
+ describe('Flash model fallback for OAuth users', () => {
+ it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
+ const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
+
+ let fallbackOccurred = false;
+ const mockFn = vi.fn().mockImplementation(async () => {
+ if (!fallbackOccurred) {
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ }
+ return 'success';
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 3,
+ initialDelayMs: 100,
+ onPersistent429: async (authType?: string) => {
+ fallbackOccurred = true;
+ return await fallbackCallback(authType);
+ },
+ authType: 'oauth-personal',
+ });
+
+ // Advance all timers to complete retries
+ await vi.runAllTimersAsync();
+
+ // Should succeed after fallback
+ await expect(promise).resolves.toBe('success');
+
+ // Verify callback was called with correct auth type
+ expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
+
+ // Should retry again after fallback
+ expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
+ });
+
+ it('should trigger fallback for OAuth enterprise users after persistent 429 errors', async () => {
+ const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
+
+ let fallbackOccurred = false;
+ const mockFn = vi.fn().mockImplementation(async () => {
+ if (!fallbackOccurred) {
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ }
+ return 'success';
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 3,
+ initialDelayMs: 100,
+ onPersistent429: async (authType?: string) => {
+ fallbackOccurred = true;
+ return await fallbackCallback(authType);
+ },
+ authType: 'oauth-enterprise',
+ });
+
+ await vi.runAllTimersAsync();
+
+ await expect(promise).resolves.toBe('success');
+ expect(fallbackCallback).toHaveBeenCalledWith('oauth-enterprise');
+ });
+
+ it('should NOT trigger fallback for API key users', async () => {
+ const fallbackCallback = vi.fn();
+
+ const mockFn = vi.fn(async () => {
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 3,
+ initialDelayMs: 100,
+ onPersistent429: fallbackCallback,
+ authType: 'gemini-api-key',
+ });
+
+ // Handle the promise properly to avoid unhandled rejections
+ const resultPromise = promise.catch((error) => error);
+ await vi.runAllTimersAsync();
+ const result = await resultPromise;
+
+ // Should fail after all retries without fallback
+ expect(result).toBeInstanceOf(Error);
+ expect(result.message).toBe('Rate limit exceeded');
+
+ // Callback should not be called for API key users
+ expect(fallbackCallback).not.toHaveBeenCalled();
+ });
+
+ it('should reset attempt counter and continue after successful fallback', async () => {
+ let fallbackCalled = false;
+ const fallbackCallback = vi.fn().mockImplementation(async () => {
+ fallbackCalled = true;
+ return 'gemini-2.5-flash';
+ });
+
+ const mockFn = vi.fn().mockImplementation(async () => {
+ if (!fallbackCalled) {
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ }
+ return 'success';
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 3,
+ initialDelayMs: 100,
+ onPersistent429: fallbackCallback,
+ authType: 'oauth-personal',
+ });
+
+ await vi.runAllTimersAsync();
+
+ await expect(promise).resolves.toBe('success');
+ expect(fallbackCallback).toHaveBeenCalledOnce();
+ });
+
+ it('should continue with original error if fallback is rejected', async () => {
+ const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
+
+ const mockFn = vi.fn(async () => {
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 3,
+ initialDelayMs: 100,
+ onPersistent429: fallbackCallback,
+ authType: 'oauth-personal',
+ });
+
+ // Handle the promise properly to avoid unhandled rejections
+ const resultPromise = promise.catch((error) => error);
+ await vi.runAllTimersAsync();
+ const result = await resultPromise;
+
+ // Should fail with original error when fallback is rejected
+ expect(result).toBeInstanceOf(Error);
+ expect(result.message).toBe('Rate limit exceeded');
+ expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
+ });
+
+ it('should handle mixed error types (only count consecutive 429s)', async () => {
+ const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
+ let attempts = 0;
+ let fallbackOccurred = false;
+
+ const mockFn = vi.fn().mockImplementation(async () => {
+ attempts++;
+ if (fallbackOccurred) {
+ return 'success';
+ }
+ if (attempts === 1) {
+ // First attempt: 500 error (resets consecutive count)
+ const error: HttpError = new Error('Server error');
+ error.status = 500;
+ throw error;
+ } else {
+ // Remaining attempts: 429 errors
+ const error: HttpError = new Error('Rate limit exceeded');
+ error.status = 429;
+ throw error;
+ }
+ });
+
+ const promise = retryWithBackoff(mockFn, {
+ maxAttempts: 5,
+ initialDelayMs: 100,
+ onPersistent429: async (authType?: string) => {
+ fallbackOccurred = true;
+ return await fallbackCallback(authType);
+ },
+ authType: 'oauth-personal',
+ });
+
+ await vi.runAllTimersAsync();
+
+ await expect(promise).resolves.toBe('success');
+
+ // Should trigger fallback after 4 consecutive 429s (attempts 2-5)
+ expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
+ });
+ });
});
diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts
index 1e7d5bcb..e0fc4ced 100644
--- a/packages/core/src/utils/retry.ts
+++ b/packages/core/src/utils/retry.ts
@@ -4,11 +4,15 @@
* SPDX-License-Identifier: Apache-2.0
*/
+import { AuthType } from '../core/contentGenerator.js';
+
export interface RetryOptions {
maxAttempts: number;
initialDelayMs: number;
maxDelayMs: number;
shouldRetry: (error: Error) => boolean;
+ onPersistent429?: (authType?: string) => Promise<string | null>;
+ authType?: string;
}
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
@@ -59,29 +63,69 @@ export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options?: Partial<RetryOptions>,
): Promise<T> {
- const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = {
+ const {
+ maxAttempts,
+ initialDelayMs,
+ maxDelayMs,
+ shouldRetry,
+ onPersistent429,
+ authType,
+ } = {
...DEFAULT_RETRY_OPTIONS,
...options,
};
let attempt = 0;
let currentDelay = initialDelayMs;
+ let consecutive429Count = 0;
while (attempt < maxAttempts) {
attempt++;
try {
return await fn();
} catch (error) {
+ const errorStatus = getErrorStatus(error);
+
+ // Track consecutive 429 errors
+ if (errorStatus === 429) {
+ consecutive429Count++;
+ } else {
+ consecutive429Count = 0;
+ }
+
+ // Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
+ // If we have persistent 429s and a fallback callback for OAuth
+ if (
+ consecutive429Count >= 3 &&
+ onPersistent429 &&
+ (authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL ||
+ authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE)
+ ) {
+ try {
+ const fallbackModel = await onPersistent429(authType);
+ if (fallbackModel) {
+ // Reset attempt counter and try with new model
+ attempt = 0;
+ consecutive429Count = 0;
+ currentDelay = initialDelayMs;
+ continue;
+ }
+ } catch (fallbackError) {
+ // If fallback fails, continue with original error
+ console.warn('Fallback to Flash model failed:', fallbackError);
+ }
+ }
throw error;
}
- const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error);
+ const { delayDurationMs, errorStatus: delayErrorStatus } =
+ getDelayDurationAndStatus(error);
if (delayDurationMs > 0) {
// Respect Retry-After header if present and parsed
console.warn(
- `Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
+ `Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
error,
);
await delay(delayDurationMs);
diff --git a/packages/core/src/utils/testUtils.ts b/packages/core/src/utils/testUtils.ts
new file mode 100644
index 00000000..a0010b10
--- /dev/null
+++ b/packages/core/src/utils/testUtils.ts
@@ -0,0 +1,87 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/**
+ * Testing utilities for simulating 429 errors in unit tests
+ */
+
+let requestCounter = 0;
+let simulate429Enabled = false;
+let simulate429AfterRequests = 0;
+let simulate429ForAuthType: string | undefined;
+let fallbackOccurred = false;
+
+/**
+ * Check if we should simulate a 429 error for the current request
+ */
+export function shouldSimulate429(authType?: string): boolean {
+ if (!simulate429Enabled || fallbackOccurred) {
+ return false;
+ }
+
+ // If auth type filter is set, only simulate for that auth type
+ if (simulate429ForAuthType && authType !== simulate429ForAuthType) {
+ return false;
+ }
+
+ requestCounter++;
+
+ // If afterRequests is set, only simulate after that many requests
+ if (simulate429AfterRequests > 0) {
+ return requestCounter > simulate429AfterRequests;
+ }
+
+ // Otherwise, simulate for every request
+ return true;
+}
+
+/**
+ * Reset the request counter (useful for tests)
+ */
+export function resetRequestCounter(): void {
+ requestCounter = 0;
+}
+
+/**
+ * Disable 429 simulation after successful fallback
+ */
+export function disableSimulationAfterFallback(): void {
+ fallbackOccurred = true;
+}
+
+/**
+ * Create a simulated 429 error response
+ */
+export function createSimulated429Error(): Error {
+ const error = new Error('Rate limit exceeded (simulated)') as Error & {
+ status: number;
+ };
+ error.status = 429;
+ return error;
+}
+
+/**
+ * Reset simulation state when switching auth methods
+ */
+export function resetSimulationState(): void {
+ fallbackOccurred = false;
+ resetRequestCounter();
+}
+
+/**
+ * Enable/disable 429 simulation programmatically (for tests)
+ */
+export function setSimulate429(
+ enabled: boolean,
+ afterRequests = 0,
+ forAuthType?: string,
+): void {
+ simulate429Enabled = enabled;
+ simulate429AfterRequests = afterRequests;
+ simulate429ForAuthType = forAuthType;
+ fallbackOccurred = false; // Reset fallback state when simulation is re-enabled
+ resetRequestCounter();
+}