summary | refs | log | tree | commit | diff
path: root/packages/core
diff options
context:
space:
mode:
author: Bryan Morgan <[email protected]> 2025-06-25 21:45:38 -0400
committer: GitHub <[email protected]> 2025-06-26 01:45:38 +0000
commit: bb797ded7d15003dfec7a68ff82764d7a2c44458 (patch)
tree: ed7af7f2dae0d9a838610d3248ee957b813b07c3 /packages/core
parent: b6b9923dc3b80a73fdee3a3ccd6070c8cfb551cd (diff)
429 fix (#1668)
Diffstat (limited to 'packages/core')
-rw-r--r-- packages/core/src/core/client.ts 1
-rw-r--r-- packages/core/src/core/geminiChat.test.ts 36
-rw-r--r-- packages/core/src/core/geminiChat.ts 38
-rw-r--r-- packages/core/src/utils/nextSpeakerChecker.test.ts 1
-rw-r--r-- packages/core/src/utils/retry.test.ts 2
-rw-r--r-- packages/core/src/utils/retry.ts 44
6 files changed, 72 insertions, 50 deletions
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 6170f319..df655b59 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -196,7 +196,6 @@ export class GeminiClient {
return new GeminiChat(
this.config,
this.getContentGenerator(),
- this.model,
{
systemInstruction,
...generateContentConfigWithThinking,
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index 18b3729e..bfaeb8f6 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -25,34 +25,36 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(),
} as unknown as Models;
-const mockConfig = {
- getSessionId: () => 'test-session-id',
- getTelemetryLogPromptsEnabled: () => true,
- getUsageStatisticsEnabled: () => true,
- getDebugMode: () => false,
- getContentGeneratorConfig: () => ({
- authType: 'oauth-personal',
- model: 'test-model',
- }),
- setModel: vi.fn(),
- flashFallbackHandler: undefined,
-} as unknown as Config;
-
describe('GeminiChat', () => {
let chat: GeminiChat;
- const model = 'gemini-pro';
+ let mockConfig: Config;
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
+ mockConfig = {
+ getSessionId: () => 'test-session-id',
+ getTelemetryLogPromptsEnabled: () => true,
+ getUsageStatisticsEnabled: () => true,
+ getDebugMode: () => false,
+ getContentGeneratorConfig: () => ({
+ authType: 'oauth-personal',
+ model: 'test-model',
+ }),
+ getModel: vi.fn().mockReturnValue('gemini-pro'),
+ setModel: vi.fn(),
+ flashFallbackHandler: undefined,
+ } as unknown as Config;
+
// Disable 429 simulation for tests
setSimulate429(false);
// Reset history for each test by creating a new instance
- chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
+ chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
});
afterEach(() => {
vi.restoreAllMocks();
+ vi.resetAllMocks();
});
describe('sendMessage', () => {
@@ -203,7 +205,7 @@ describe('GeminiChat', () => {
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
// Reset and set up a more realistic scenario for merging with existing history
- chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
+ chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
const firstUserInput: Content = {
role: 'user',
parts: [{ text: 'First user input' }],
@@ -246,7 +248,7 @@ describe('GeminiChat', () => {
role: 'model',
parts: [{ text: 'Initial model answer.' }],
};
- chat = new GeminiChat(mockConfig, mockModelsModule, model, config, [
+ chat = new GeminiChat(mockConfig, mockModelsModule, config, [
initialUser,
initialModel,
]);
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index ce5accf4..ac4f4898 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -138,7 +138,6 @@ export class GeminiChat {
constructor(
private readonly config: Config,
private readonly contentGenerator: ContentGenerator,
- private readonly model: string,
private readonly generationConfig: GenerateContentConfig = {},
private history: Content[] = [],
) {
@@ -168,7 +167,12 @@ export class GeminiChat {
): Promise<void> {
logApiResponse(
this.config,
- new ApiResponseEvent(this.model, durationMs, usageMetadata, responseText),
+ new ApiResponseEvent(
+ this.config.getModel(),
+ durationMs,
+ usageMetadata,
+ responseText,
+ ),
);
}
@@ -178,7 +182,12 @@ export class GeminiChat {
logApiError(
this.config,
- new ApiErrorEvent(this.model, errorMessage, durationMs, errorType),
+ new ApiErrorEvent(
+ this.config.getModel(),
+ errorMessage,
+ durationMs,
+ errorType,
+ ),
);
}
@@ -192,7 +201,7 @@ export class GeminiChat {
return null;
}
- const currentModel = this.model;
+ const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
@@ -244,7 +253,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
- this._logApiRequest(requestContents, this.model);
+ this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now();
let response: GenerateContentResponse;
@@ -252,12 +261,23 @@ export class GeminiChat {
try {
const apiCall = () =>
this.contentGenerator.generateContent({
- model: this.model,
+ model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
- response = await retryWithBackoff(apiCall);
+ response = await retryWithBackoff(apiCall, {
+ shouldRetry: (error: Error) => {
+ if (error && error.message) {
+ if (error.message.includes('429')) return true;
+ if (error.message.match(/5\d{2}/)) return true;
+ }
+ return false;
+ },
+ onPersistent429: async (authType?: string) =>
+ await this.handleFlashFallback(authType),
+ authType: this.config.getContentGeneratorConfig()?.authType,
+ });
const durationMs = Date.now() - startTime;
await this._logApiResponse(
durationMs,
@@ -326,14 +346,14 @@ export class GeminiChat {
await this.sendPromise;
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
- this._logApiRequest(requestContents, this.model);
+ this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now();
try {
const apiCall = () =>
this.contentGenerator.generateContentStream({
- model: this.model,
+ model: this.config.getModel(),
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts
index 83ce97fd..475b5662 100644
--- a/packages/core/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/core/src/utils/nextSpeakerChecker.test.ts
@@ -71,7 +71,6 @@ describe('checkNextSpeaker', () => {
chatInstance = new GeminiChat(
mockConfigInstance,
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
- 'gemini-pro', // model name
{},
[], // initial history
);
diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts
index 1988c02a..a0294c31 100644
--- a/packages/core/src/utils/retry.test.ts
+++ b/packages/core/src/utils/retry.test.ts
@@ -272,7 +272,7 @@ describe('retryWithBackoff', () => {
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
- expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
+ expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
});
it('should NOT trigger fallback for API key users', async () => {
diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts
index ebe18510..372a7976 100644
--- a/packages/core/src/utils/retry.ts
+++ b/packages/core/src/utils/retry.ts
@@ -67,9 +67,9 @@ export async function retryWithBackoff<T>(
maxAttempts,
initialDelayMs,
maxDelayMs,
- shouldRetry,
onPersistent429,
authType,
+ shouldRetry,
} = {
...DEFAULT_RETRY_OPTIONS,
...options,
@@ -93,28 +93,30 @@ export async function retryWithBackoff<T>(
consecutive429Count = 0;
}
- // Check if we've exhausted retries or shouldn't retry
- if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
- // If we have persistent 429s and a fallback callback for OAuth
- if (
- consecutive429Count >= 2 &&
- onPersistent429 &&
- authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
- ) {
- try {
- const fallbackModel = await onPersistent429(authType);
- if (fallbackModel) {
- // Reset attempt counter and try with new model
- attempt = 0;
- consecutive429Count = 0;
- currentDelay = initialDelayMs;
- continue;
- }
- } catch (fallbackError) {
- // If fallback fails, continue with original error
- console.warn('Fallback to Flash model failed:', fallbackError);
+ // If we have persistent 429s and a fallback callback for OAuth
+ if (
+ consecutive429Count >= 2 &&
+ onPersistent429 &&
+ authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
+ ) {
+ try {
+ const fallbackModel = await onPersistent429(authType);
+ if (fallbackModel) {
+ // Reset attempt counter and try with new model
+ attempt = 0;
+ consecutive429Count = 0;
+ currentDelay = initialDelayMs;
+ // With the model updated, we continue to the next attempt
+ continue;
}
+ } catch (fallbackError) {
+ // If fallback fails, continue with original error
+ console.warn('Fallback to Flash model failed:', fallbackError);
}
+ }
+
+ // Check if we've exhausted retries or shouldn't retry
+ if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
throw error;
}