summaryrefslogtreecommitdiff
path: root/packages/cli/src
diff options
context:
space:
mode:
authorAllen Hutchison <[email protected]>2025-06-02 13:55:54 -0700
committerGitHub <[email protected]>2025-06-02 13:55:54 -0700
commit7f20425c98d5adb5531e6c33ed92975b71b34c90 (patch)
treed15a9489520c3ebe9f7b22b998803a0b7fb07a8c /packages/cli/src
parent59b6267b2f3f5d971c10eeaaf9c0e7e82f10cf02 (diff)
feat(cli): add pro model availability check and fallback to flash (#608)
Diffstat (limited to 'packages/cli/src')
-rw-r--r--packages/cli/src/config/config.test.ts16
-rw-r--r--packages/cli/src/config/config.ts46
-rw-r--r--packages/cli/src/config/settings.ts2
-rw-r--r--packages/cli/src/gemini.tsx18
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts2
-rw-r--r--packages/cli/src/utils/modelCheck.test.ts179
-rw-r--r--packages/cli/src/utils/modelCheck.ts75
7 files changed, 320 insertions, 18 deletions
diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts
index 9f288372..a39278bc 100644
--- a/packages/cli/src/config/config.test.ts
+++ b/packages/cli/src/config/config.test.ts
@@ -82,29 +82,29 @@ describe('loadCliConfig', () => {
it('should set showMemoryUsage to true when --memory flag is present', async () => {
process.argv = ['node', 'script.js', '--show_memory_usage'];
const settings: Settings = {};
- const config = await loadCliConfig(settings);
- expect(config.getShowMemoryUsage()).toBe(true);
+ const result = await loadCliConfig(settings);
+ expect(result.config.getShowMemoryUsage()).toBe(true);
});
it('should set showMemoryUsage to false when --memory flag is not present', async () => {
process.argv = ['node', 'script.js'];
const settings: Settings = {};
- const config = await loadCliConfig(settings);
- expect(config.getShowMemoryUsage()).toBe(false);
+ const result = await loadCliConfig(settings);
+ expect(result.config.getShowMemoryUsage()).toBe(false);
});
it('should set showMemoryUsage to false by default from settings if CLI flag is not present', async () => {
process.argv = ['node', 'script.js'];
const settings: Settings = { showMemoryUsage: false };
- const config = await loadCliConfig(settings);
- expect(config.getShowMemoryUsage()).toBe(false);
+ const result = await loadCliConfig(settings);
+ expect(result.config.getShowMemoryUsage()).toBe(false);
});
it('should prioritize CLI flag over settings for showMemoryUsage (CLI true, settings false)', async () => {
process.argv = ['node', 'script.js', '--show_memory_usage'];
const settings: Settings = { showMemoryUsage: false };
- const config = await loadCliConfig(settings);
- expect(config.getShowMemoryUsage()).toBe(true);
+ const result = await loadCliConfig(settings);
+ expect(result.config.getShowMemoryUsage()).toBe(true);
});
});
diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts
index ee1c9d36..2429ad64 100644
--- a/packages/cli/src/config/config.ts
+++ b/packages/cli/src/config/config.ts
@@ -19,6 +19,10 @@ import {
} from '@gemini-code/core';
import { Settings } from './settings.js';
import { readPackageUp } from 'read-package-up';
+import {
+ getEffectiveModel,
+ type EffectiveModelCheckResult,
+} from '../utils/modelCheck.js';
// Simple console logger for now - replace with actual logger if available
const logger = {
@@ -30,7 +34,8 @@ const logger = {
error: (...args: any[]) => console.error('[ERROR]', ...args),
};
-const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-05-06';
+export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-05-06';
+export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
interface CliArgs {
model: string | undefined;
@@ -114,7 +119,16 @@ export async function loadHierarchicalGeminiMemory(
return loadServerHierarchicalMemory(currentWorkingDirectory, debugMode);
}
-export async function loadCliConfig(settings: Settings): Promise<Config> {
+export interface LoadCliConfigResult {
+ config: Config;
+ modelWasSwitched: boolean;
+ originalModelBeforeSwitch?: string;
+ finalModel: string;
+}
+
+export async function loadCliConfig(
+ settings: Settings,
+): Promise<LoadCliConfigResult> {
loadEnvironment();
const geminiApiKey = process.env.GEMINI_API_KEY;
@@ -164,9 +178,27 @@ export async function loadCliConfig(settings: Settings): Promise<Config> {
const apiKeyForServer = geminiApiKey || googleApiKey || '';
const useVertexAI = hasGeminiApiKey ? false : undefined;
+ let modelToUse = argv.model || DEFAULT_GEMINI_MODEL;
+ let modelSwitched = false;
+ let originalModel: string | undefined = undefined;
+
+ if (apiKeyForServer) {
+ const checkResult: EffectiveModelCheckResult = await getEffectiveModel(
+ apiKeyForServer,
+ modelToUse,
+ );
+ if (checkResult.switched) {
+ modelSwitched = true;
+ originalModel = checkResult.originalModelIfSwitched;
+ modelToUse = checkResult.effectiveModel;
+ }
+ } else {
+ // logger.debug('API key not available during config load. Skipping model availability check.');
+ }
+
const configParams: ConfigParameters = {
apiKey: apiKeyForServer,
- model: argv.model || DEFAULT_GEMINI_MODEL,
+ model: modelToUse,
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
targetDir: process.cwd(),
debugMode,
@@ -186,7 +218,13 @@ export async function loadCliConfig(settings: Settings): Promise<Config> {
argv.show_memory_usage || settings.showMemoryUsage || false,
};
- return createServerConfig(configParams);
+ const config = createServerConfig(configParams);
+ return {
+ config,
+ modelWasSwitched: modelSwitched,
+ originalModelBeforeSwitch: originalModel,
+ finalModel: modelToUse,
+ };
}
async function createUserAgent(): Promise<string> {
diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts
index 5d51ba15..6c14a6dc 100644
--- a/packages/cli/src/config/settings.ts
+++ b/packages/cli/src/config/settings.ts
@@ -7,7 +7,7 @@
import * as fs from 'fs';
import * as path from 'path';
import { homedir } from 'os';
-import { MCPServerConfig } from '@gemini-code/core/src/config/config.js';
+import { MCPServerConfig } from '@gemini-code/core';
import stripJsonComments from 'strip-json-comments';
import { DefaultLight } from '../ui/themes/default-light.js';
import { DefaultDark } from '../ui/themes/default.js';
diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx
index 07551813..f8cc77b6 100644
--- a/packages/cli/src/gemini.tsx
+++ b/packages/cli/src/gemini.tsx
@@ -50,11 +50,19 @@ async function main() {
console.warn(
'GEMINI_CODE_SANDBOX_IMAGE is deprecated. Use GEMINI_SANDBOX_IMAGE_NAME instead.',
);
- process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE;
+ process.env.GEMINI_SANDBOX_IMAGE_NAME =
+ process.env.GEMINI_CODE_SANDBOX_IMAGE; // Corrected to GEMINI_SANDBOX_IMAGE_NAME
}
const settings = loadSettings(process.cwd());
- const config = await loadCliConfig(settings.merged);
+ const { config, modelWasSwitched, originalModelBeforeSwitch, finalModel } =
+ await loadCliConfig(settings.merged);
+
+ if (modelWasSwitched && originalModelBeforeSwitch) {
+ console.log(
+ `[INFO] Your configured model (${originalModelBeforeSwitch}) was temporarily unavailable. Switched to ${finalModel} for this session.`,
+ );
+ }
if (settings.merged.theme) {
if (!themeManager.setActiveTheme(settings.merged.theme)) {
@@ -128,8 +136,10 @@ async function main() {
...settings.merged,
coreTools: nonInteractiveTools,
};
- const nonInteractiveConfig = await loadCliConfig(nonInteractiveSettings);
- await runNonInteractive(nonInteractiveConfig, input);
+ const nonInteractiveConfigResult = await loadCliConfig(
+ nonInteractiveSettings,
+ ); // Ensure config is reloaded with non-interactive tools
+ await runNonInteractive(nonInteractiveConfigResult.config, input);
}
// --- Global Unhandled Rejection Handler ---
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index b6ef1481..423f3489 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -17,6 +17,7 @@ import {
Config,
MessageSenderType,
ToolCallRequestInfo,
+ GeminiChat,
} from '@gemini-code/core';
import { type PartListUnion } from '@google/genai';
import {
@@ -40,7 +41,6 @@ import {
TrackedCompletedToolCall,
TrackedCancelledToolCall,
} from './useReactToolScheduler.js';
-import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = [];
diff --git a/packages/cli/src/utils/modelCheck.test.ts b/packages/cli/src/utils/modelCheck.test.ts
new file mode 100644
index 00000000..3b1cded8
--- /dev/null
+++ b/packages/cli/src/utils/modelCheck.test.ts
@@ -0,0 +1,179 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import {
+ getEffectiveModel,
+ type EffectiveModelCheckResult,
+} from './modelCheck.js';
+import {
+ DEFAULT_GEMINI_MODEL,
+ DEFAULT_GEMINI_FLASH_MODEL,
+} from '../config/config.js';
+
+// Mock global fetch
+global.fetch = vi.fn();
+
+// Mock AbortController
+const mockAbort = vi.fn();
+global.AbortController = vi.fn(() => ({
+ signal: { aborted: false }, // Start with not aborted
+ abort: mockAbort,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+})) as any;
+
+describe('getEffectiveModel', () => {
+ const apiKey = 'test-api-key';
+
+ beforeEach(() => {
+ vi.useFakeTimers();
+ vi.clearAllMocks();
+ // Reset signal for each test if AbortController mock is more complex
+ global.AbortController = vi.fn(() => ({
+ signal: { aborted: false },
+ abort: mockAbort,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ })) as any;
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ vi.useRealTimers();
+ });
+
+ describe('when currentConfiguredModel is not DEFAULT_GEMINI_MODEL', () => {
+ it('should return the currentConfiguredModel and switched: false without fetching', async () => {
+ const customModel = 'custom-model-name';
+ const result = await getEffectiveModel(apiKey, customModel);
+ expect(result).toEqual({
+ effectiveModel: customModel,
+ switched: false,
+ });
+ expect(fetch).not.toHaveBeenCalled();
+ });
+ });
+
+ describe('when currentConfiguredModel is DEFAULT_GEMINI_MODEL', () => {
+ it('should switch to DEFAULT_GEMINI_FLASH_MODEL if fetch returns 429', async () => {
+ (fetch as vi.Mock).mockResolvedValueOnce({
+ ok: false,
+ status: 429,
+ });
+ const result: EffectiveModelCheckResult = await getEffectiveModel(
+ apiKey,
+ DEFAULT_GEMINI_MODEL,
+ );
+ expect(result).toEqual({
+ effectiveModel: DEFAULT_GEMINI_FLASH_MODEL,
+ switched: true,
+ originalModelIfSwitched: DEFAULT_GEMINI_MODEL,
+ });
+ expect(fetch).toHaveBeenCalledTimes(1);
+ expect(fetch).toHaveBeenCalledWith(
+ `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${apiKey}`,
+ expect.any(Object),
+ );
+ });
+
+ it('should return DEFAULT_GEMINI_MODEL if fetch returns 200', async () => {
+ (fetch as vi.Mock).mockResolvedValueOnce({
+ ok: true,
+ status: 200,
+ });
+ const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL);
+ expect(result).toEqual({
+ effectiveModel: DEFAULT_GEMINI_MODEL,
+ switched: false,
+ });
+ expect(fetch).toHaveBeenCalledTimes(1);
+ });
+
+ it('should return DEFAULT_GEMINI_MODEL if fetch returns a non-429 error status (e.g., 500)', async () => {
+ (fetch as vi.Mock).mockResolvedValueOnce({
+ ok: false,
+ status: 500,
+ });
+ const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL);
+ expect(result).toEqual({
+ effectiveModel: DEFAULT_GEMINI_MODEL,
+ switched: false,
+ });
+ expect(fetch).toHaveBeenCalledTimes(1);
+ });
+
+ it('should return DEFAULT_GEMINI_MODEL if fetch throws a network error', async () => {
+ (fetch as vi.Mock).mockRejectedValueOnce(new Error('Network error'));
+ const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL);
+ expect(result).toEqual({
+ effectiveModel: DEFAULT_GEMINI_MODEL,
+ switched: false,
+ });
+ expect(fetch).toHaveBeenCalledTimes(1);
+ });
+
+ it('should return DEFAULT_GEMINI_MODEL if fetch times out', async () => {
+ // Simulate AbortController's signal changing and fetch throwing AbortError
+ const abortControllerInstance = {
+ signal: { aborted: false }, // mutable signal
+ abort: vi.fn(() => {
+ abortControllerInstance.signal.aborted = true; // Use abortControllerInstance
+ }),
+ };
+ (global.AbortController as vi.Mock).mockImplementationOnce(
+ () => abortControllerInstance,
+ );
+
+ (fetch as vi.Mock).mockImplementationOnce(
+ async ({ signal }: { signal: AbortSignal }) => {
+ // Simulate the timeout advancing and abort being called
+ vi.advanceTimersByTime(2000);
+ if (signal.aborted) {
+ throw new DOMException('Aborted', 'AbortError');
+ }
+ // Should not reach here in a timeout scenario
+ return { ok: true, status: 200 };
+ },
+ );
+
+ const resultPromise = getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL);
+ // Ensure timers are advanced to trigger the timeout within getEffectiveModel
+ await vi.advanceTimersToNextTimerAsync(); // Or advanceTimersByTime(2000) if more precise control is needed
+
+ const result = await resultPromise;
+
+ expect(mockAbort).toHaveBeenCalledTimes(0); // setTimeout calls controller.abort(), not our direct mockAbort
+ expect(abortControllerInstance.abort).toHaveBeenCalledTimes(1);
+ expect(result).toEqual({
+ effectiveModel: DEFAULT_GEMINI_MODEL,
+ switched: false,
+ });
+ expect(fetch).toHaveBeenCalledTimes(1);
+ });
+
+ it('should correctly pass API key and model in the fetch request', async () => {
+ (fetch as vi.Mock).mockResolvedValueOnce({ ok: true, status: 200 });
+ const specificApiKey = 'specific-key-for-this-test';
+ await getEffectiveModel(specificApiKey, DEFAULT_GEMINI_MODEL);
+
+ expect(fetch).toHaveBeenCalledWith(
+ `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${specificApiKey}`,
+ expect.objectContaining({
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({
+ contents: [{ parts: [{ text: 'test' }] }],
+ generationConfig: {
+ maxOutputTokens: 1,
+ temperature: 0,
+ topK: 1,
+ thinkingConfig: { thinkingBudget: 0, includeThoughts: false },
+ },
+ }),
+ }),
+ );
+ });
+ });
+});
diff --git a/packages/cli/src/utils/modelCheck.ts b/packages/cli/src/utils/modelCheck.ts
new file mode 100644
index 00000000..1634656e
--- /dev/null
+++ b/packages/cli/src/utils/modelCheck.ts
@@ -0,0 +1,75 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+ DEFAULT_GEMINI_MODEL,
+ DEFAULT_GEMINI_FLASH_MODEL,
+} from '../config/config.js';
+
+export interface EffectiveModelCheckResult {
+ effectiveModel: string;
+ switched: boolean;
+ originalModelIfSwitched?: string;
+}
+
+/**
+ * Checks if the default "pro" model is rate-limited and returns a fallback "flash"
+ * model if necessary. This function is designed to be silent.
+ * @param apiKey The API key to use for the check.
+ * @param currentConfiguredModel The model currently configured in settings.
+ * @returns An object indicating the model to use, whether a switch occurred,
+ * and the original model if a switch happened.
+ */
+export async function getEffectiveModel(
+ apiKey: string,
+ currentConfiguredModel: string,
+): Promise<EffectiveModelCheckResult> {
+ if (currentConfiguredModel !== DEFAULT_GEMINI_MODEL) {
+ // Only check if the user is trying to use the specific pro model we want to fallback from.
+ return { effectiveModel: currentConfiguredModel, switched: false };
+ }
+
+ const modelToTest = DEFAULT_GEMINI_MODEL;
+ const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+ const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${modelToTest}:generateContent?key=${apiKey}`;
+ const body = JSON.stringify({
+ contents: [{ parts: [{ text: 'test' }] }],
+ generationConfig: {
+ maxOutputTokens: 1,
+ temperature: 0,
+ topK: 1,
+ thinkingConfig: { thinkingBudget: 0, includeThoughts: false },
+ },
+ });
+
+ const controller = new AbortController();
+ const timeoutId = setTimeout(() => controller.abort(), 2000); // 500ms timeout for the request
+
+ try {
+ const response = await fetch(endpoint, {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body,
+ signal: controller.signal,
+ });
+
+ clearTimeout(timeoutId);
+
+ if (response.status === 429) {
+ return {
+ effectiveModel: fallbackModel,
+ switched: true,
+ originalModelIfSwitched: modelToTest,
+ };
+ }
+ // For any other case (success, other error codes), we stick to the original model.
+ return { effectiveModel: currentConfiguredModel, switched: false };
+ } catch (_error) {
+ clearTimeout(timeoutId);
+ // On timeout or any other fetch error, stick to the original model.
+ return { effectiveModel: currentConfiguredModel, switched: false };
+ }
+}