diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/config/auth.ts | 44 | ||||
| -rw-r--r-- | packages/cli/src/config/config.test.ts | 42 | ||||
| -rw-r--r-- | packages/cli/src/config/config.ts | 60 | ||||
| -rw-r--r-- | packages/cli/src/config/settings.ts | 2 | ||||
| -rw-r--r-- | packages/cli/src/gemini.tsx | 85 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.test.tsx | 19 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.tsx | 42 | ||||
| -rw-r--r-- | packages/cli/src/ui/components/AuthDialog.test.tsx | 41 | ||||
| -rw-r--r-- | packages/cli/src/ui/components/AuthDialog.tsx | 94 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 3 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 9 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useAuthCommand.ts | 57 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 1 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 7 | ||||
| -rw-r--r-- | packages/cli/src/utils/modelCheck.test.ts | 154 | ||||
| -rw-r--r-- | packages/cli/src/utils/modelCheck.ts | 68 |
16 files changed, 375 insertions, 353 deletions
diff --git a/packages/cli/src/config/auth.ts b/packages/cli/src/config/auth.ts new file mode 100644 index 00000000..6153044e --- /dev/null +++ b/packages/cli/src/config/auth.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { AuthType } from '@gemini-cli/core'; +import { loadEnvironment } from './config.js'; + +export const validateAuthMethod = (authMethod: string): string | null => { + loadEnvironment(); + if (authMethod === AuthType.LOGIN_WITH_GOOGLE_PERSONAL) { + return null; + } + + if (authMethod === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE) { + if (!process.env.GOOGLE_CLOUD_PROJECT) { + return 'GOOGLE_CLOUD_PROJECT environment variable not found. Add that to your .env and try again, no reload needed!'; + } + return null; + } + + if (authMethod === AuthType.USE_GEMINI) { + if (!process.env.GEMINI_API_KEY) { + return 'GEMINI_API_KEY environment variable not found. Add that to your .env and try again, no reload needed!'; + } + return null; + } + + if (authMethod === AuthType.USE_VERTEX_AI) { + if (!process.env.GOOGLE_API_KEY) { + return 'GOOGLE_API_KEY environment variable not found. Add that to your .env and try again, no reload needed!'; + } + if (!process.env.GOOGLE_CLOUD_PROJECT) { + return 'GOOGLE_CLOUD_PROJECT environment variable not found. Add that to your .env and try again, no reload needed!'; + } + if (!process.env.GOOGLE_CLOUD_LOCATION) { + return 'GOOGLE_CLOUD_LOCATION environment variable not found. Add that to your .env and try again, no reload needed!'; + } + return null; + } + + return 'Invalid auth method selected.'; +}; diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 9ff75b15..5b24f434 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -247,48 +247,6 @@ describe('loadCliConfig telemetry', () => { }); }); -describe('API Key Handling', () => { - const originalEnv = { ...process.env }; - const originalArgv = process.argv; - - beforeEach(() => { - vi.resetAllMocks(); - process.argv = ['node', 'script.js']; - }); - - afterEach(() => { - process.env = originalEnv; - process.argv = originalArgv; - }); - - it('should use GEMINI_API_KEY from env', async () => { - process.env.GEMINI_API_KEY = 'gemini-key'; - delete process.env.GOOGLE_API_KEY; - - const settings: Settings = {}; - const result = await loadCliConfig(settings, [], 'test-session'); - expect(result.getContentGeneratorConfig().apiKey).toBe('gemini-key'); - }); - - it('should use GOOGLE_API_KEY and warn when both GOOGLE_API_KEY and GEMINI_API_KEY are set', async () => { - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - - process.env.GEMINI_API_KEY = 'gemini-key'; - process.env.GOOGLE_API_KEY = 'google-key'; - - const settings: Settings = {}; - const result = await loadCliConfig(settings, [], 'test-session'); - - expect(consoleWarnSpy).toHaveBeenCalledWith( - '[WARN]', - 'Both GEMINI_API_KEY and GOOGLE_API_KEY are set. Using GOOGLE_API_KEY.', - ); - expect(result.getContentGeneratorConfig().apiKey).toBe('google-key'); - }); -}); - describe('Hierarchical Memory Loading (config.ts) - Placeholder Suite', () => { beforeEach(() => { vi.resetAllMocks(); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 26878646..6e52c6e2 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -13,7 +13,6 @@ import { setGeminiMdFilename as setServerGeminiMdFilename, getCurrentGeminiMdFilename, ApprovalMode, - ContentGeneratorConfig, GEMINI_CONFIG_DIR as GEMINI_DIR, DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_EMBEDDING_MODEL, @@ -21,7 +20,7 @@ import { TelemetryTarget, } from '@gemini-cli/core'; import { Settings } from './settings.js'; -import { getEffectiveModel } from '../utils/modelCheck.js'; + import { Extension } from './extension.js'; import { getCliVersion } from '../utils/version.js'; import * as dotenv from 'dotenv'; @@ -194,15 +193,12 @@ export async function loadCliConfig( extensionContextFilePaths, ); - const contentGeneratorConfig = await createContentGeneratorConfig(argv); - const mcpServers = mergeMcpServers(settings, extensions); const sandboxConfig = await loadSandboxConfig(settings, argv); return new Config({ sessionId, - contentGeneratorConfig, embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL, sandbox: sandboxConfig, targetDir: process.cwd(), @@ -242,6 +238,7 @@ export async function loadCliConfig( cwd: process.cwd(), fileDiscoveryService: fileService, bugCommand: settings.bugCommand, + model: argv.model!, }); } @@ -262,59 +259,6 @@ function mergeMcpServers(settings: Settings, extensions: Extension[]) { } return mcpServers; } - -async function createContentGeneratorConfig( - argv: CliArgs, -): Promise<ContentGeneratorConfig> { - const geminiApiKey = process.env.GEMINI_API_KEY; - const googleApiKey = process.env.GOOGLE_API_KEY; - const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT; - const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION; - - const hasCodeAssist = process.env.GEMINI_CODE_ASSIST === 'true'; - const hasGeminiApiKey = !!geminiApiKey; - const hasGoogleApiKey = !!googleApiKey; - const hasVertexProjectLocationConfig = - !!googleCloudProject && !!googleCloudLocation; - - if (hasGeminiApiKey && hasGoogleApiKey) { - logger.warn( - 'Both GEMINI_API_KEY and GOOGLE_API_KEY are set. Using GOOGLE_API_KEY.', - ); - } - if ( - !hasCodeAssist && - !hasGeminiApiKey && - !hasGoogleApiKey && - !hasVertexProjectLocationConfig - ) { - logger.error( - 'No valid API authentication configuration found. Please set ONE of the following combinations in your environment variables or .env file:\n' + - '1. GEMINI_CODE_ASSIST=true (for Code Assist access).\n' + - '2. GEMINI_API_KEY (for Gemini API access).\n' + - '3. GOOGLE_API_KEY (for Gemini API or Vertex AI Express Mode access).\n' + - '4. GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION (for Vertex AI access).\n\n' + - 'For Gemini API keys, visit: https://ai.google.dev/gemini-api/docs/api-key\n' + - 'For Vertex AI authentication, visit: https://cloud.google.com/vertex-ai/docs/authentication\n' + - 'The GOOGLE_GENAI_USE_VERTEXAI environment variable can also be set to true/false to influence service selection when ambiguity exists.', - ); - process.exit(1); - } - - const config: ContentGeneratorConfig = { - model: argv.model || DEFAULT_GEMINI_MODEL, - apiKey: googleApiKey || geminiApiKey || '', - vertexai: hasGeminiApiKey ? false : undefined, - codeAssist: hasCodeAssist, - }; - - if (config.apiKey) { - config.model = await getEffectiveModel(config.apiKey, config.model); - } - - return config; -} - function findEnvFile(startDir: string): string | null { let currentDir = path.resolve(startDir); while (true) { diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index b17b4c9d..a90ed2d8 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -12,6 +12,7 @@ import { getErrorMessage, BugCommandSettings, TelemetrySettings, + AuthType, } from '@gemini-cli/core'; import stripJsonComments from 'strip-json-comments'; import { DefaultLight } from '../ui/themes/default-light.js'; @@ -32,6 +33,7 @@ export interface AccessibilitySettings { export interface Settings { theme?: string; + selectedAuthType?: AuthType; sandbox?: boolean | string; coreTools?: string[]; excludeTools?: string[]; diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 7be84649..8dd52117 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -25,7 +25,9 @@ import { WriteFileTool, sessionId, logUserPrompt, + AuthType, } from '@gemini-cli/core'; +import { validateAuthMethod } from './config/auth.js'; export async function main() { const workspaceRoot = process.cwd(); @@ -47,10 +49,6 @@ export async function main() { const extensions = loadExtensions(workspaceRoot); const config = await loadCliConfig(settings.merged, extensions, sessionId); - // When using Code Assist this triggers the Oauth login. - // Do this now, before sandboxing, so web redirect works. - await config.getGeminiClient().initialize(); - // Initialize centralized FileDiscoveryService config.getFileService(); if (config.getCheckpointEnabled()) { @@ -73,6 +71,15 @@ export async function main() { if (!process.env.SANDBOX) { const sandboxConfig = config.getSandbox(); if (sandboxConfig) { + if (settings.merged.selectedAuthType) { + // Validate authentication here because the sandbox will interfere with the Oauth2 web redirect. + const err = validateAuthMethod(settings.merged.selectedAuthType); + if (err) { + console.error(err); + process.exit(1); + } + await config.refreshAuth(settings.merged.selectedAuthType); + } await start_sandbox(sandboxConfig); process.exit(0); } @@ -152,28 +159,58 @@ async function loadNonInteractiveConfig( extensions: Extension[], settings: LoadedSettings, ) { - if (config.getApprovalMode() === ApprovalMode.YOLO) { - // Since everything is being allowed we can use normal yolo behavior. - return config; - } + let finalConfig = config; + if (config.getApprovalMode() !== ApprovalMode.YOLO) { + // Everything is not allowed, ensure that only read-only tools are configured. + const existingExcludeTools = settings.merged.excludeTools || []; + const interactiveTools = [ + ShellTool.Name, + EditTool.Name, + WriteFileTool.Name, + ]; - // Everything is not allowed, ensure that only read-only tools are configured. - const existingExcludeTools = settings.merged.excludeTools || []; - const interactiveTools = [ShellTool.Name, EditTool.Name, WriteFileTool.Name]; + const newExcludeTools = [ + ...new Set([...existingExcludeTools, ...interactiveTools]), + ]; - const newExcludeTools = [ - ...new Set([...existingExcludeTools, ...interactiveTools]), - ]; + const nonInteractiveSettings = { + ...settings.merged, + excludeTools: newExcludeTools, + }; + finalConfig = await loadCliConfig( + nonInteractiveSettings, + extensions, + config.getSessionId(), + ); + } - const nonInteractiveSettings = { - ...settings.merged, - excludeTools: newExcludeTools, - }; - const newConfig = await loadCliConfig( - nonInteractiveSettings, - extensions, - config.getSessionId(), + return await validateNonInterActiveAuth( + settings.merged.selectedAuthType, + finalConfig, ); - await newConfig.getGeminiClient().initialize(); - return newConfig; +} + +async function validateNonInterActiveAuth( + selectedAuthType: AuthType | undefined, + nonInteractiveConfig: Config, +) { + // making a special case for the cli. many headless environments might not have a settings.json set + // so if GEMINI_API_KEY is set, we'll use that. However since the oauth things are interactive anyway, we'll + // still expect that exists + if (!selectedAuthType && !process.env.GEMINI_API_KEY) { + console.error( + 'Please set an Auth method in your .gemini/settings.json OR specify GEMINI_API_KEY env variable file before running', + ); + process.exit(1); + } + + selectedAuthType = selectedAuthType || AuthType.USE_GEMINI; + const err = validateAuthMethod(selectedAuthType); + if (err != null) { + console.error(err); + process.exit(1); + } + + await nonInteractiveConfig.refreshAuth(selectedAuthType); + return nonInteractiveConfig; } diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 0ebaa34d..dca24b5c 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -145,6 +145,15 @@ vi.mock('./hooks/useGeminiStream', () => ({ })), })); +vi.mock('./hooks/useAuthCommand', () => ({ + useAuthCommand: vi.fn(() => ({ + isAuthDialogOpen: false, + openAuthDialog: vi.fn(), + handleAuthSelect: vi.fn(), + handleAuthHighlight: vi.fn(), + })), +})); + vi.mock('./hooks/useLogger', () => ({ useLogger: vi.fn(() => ({ getPreviousUserMessages: vi.fn().mockResolvedValue([]), @@ -176,7 +185,9 @@ describe('App UI', () => { }; const workspaceSettingsFile: SettingsFile = { path: '/workspace/.gemini/settings.json', - settings, + settings: { + ...settings, + }, }; return new LoadedSettings(userSettingsFile, workspaceSettingsFile, []); }; @@ -184,10 +195,6 @@ describe('App UI', () => { beforeEach(() => { const ServerConfigMocked = vi.mocked(ServerConfig, true); mockConfig = new ServerConfigMocked({ - contentGeneratorConfig: { - apiKey: 'test-key', - model: 'test-model', - }, embeddingModel: 'test-embedding-model', sandbox: undefined, targetDir: '/test/dir', @@ -197,7 +204,7 @@ describe('App UI', () => { showMemoryUsage: false, sessionId: 'test-session-id', cwd: '/tmp', - // Provide other required fields for ConfigParameters if necessary + model: 'model', }) as unknown as MockServerConfig; // Ensure the getShowMemoryUsage mock function is specifically set up if not covered by constructor mock diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index a9c5f0e7..c481ebd3 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -20,6 +20,7 @@ import { useTerminalSize } from './hooks/useTerminalSize.js'; import { useGeminiStream } from './hooks/useGeminiStream.js'; import { useLoadingIndicator } from './hooks/useLoadingIndicator.js'; import { useThemeCommand } from './hooks/useThemeCommand.js'; +import { useAuthCommand } from './hooks/useAuthCommand.js'; import { useEditorSettings } from './hooks/useEditorSettings.js'; import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js'; import { useAutoAcceptIndicator } from './hooks/useAutoAcceptIndicator.js'; @@ -31,6 +32,7 @@ import { ShellModeIndicator } from './components/ShellModeIndicator.js'; import { InputPrompt } from './components/InputPrompt.js'; import { Footer } from './components/Footer.js'; import { ThemeDialog } from './components/ThemeDialog.js'; +import { AuthDialog } from './components/AuthDialog.js'; import { EditorSettingsDialog } from './components/EditorSettingsDialog.js'; import { Colors } from './colors.js'; import { Help } from './components/Help.js'; @@ -51,6 +53,7 @@ import { isEditorAvailable, EditorType, } from '@gemini-cli/core'; +import { validateAuthMethod } from '../config/auth.js'; import { useLogger } from './hooks/useLogger.js'; import { StreamingContext } from './contexts/StreamingContext.js'; import { @@ -101,6 +104,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const [debugMessage, setDebugMessage] = useState<string>(''); const [showHelp, setShowHelp] = useState<boolean>(false); const [themeError, setThemeError] = useState<string | null>(null); + const [authError, setAuthError] = useState<string | null>(null); const [editorError, setEditorError] = useState<string | null>(null); const [footerHeight, setFooterHeight] = useState<number>(0); const [corgiMode, setCorgiMode] = useState(false); @@ -130,6 +134,23 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { } = useThemeCommand(settings, setThemeError, addItem); const { + isAuthDialogOpen, + openAuthDialog, + handleAuthSelect, + handleAuthHighlight, + } = useAuthCommand(settings, setAuthError, config); + + useEffect(() => { + if (settings.merged.selectedAuthType) { + const error = validateAuthMethod(settings.merged.selectedAuthType); + if (error) { + setAuthError(error); + openAuthDialog(); + } + } + }, [settings.merged.selectedAuthType, openAuthDialog, setAuthError]); + + const { isEditorDialogOpen, openEditorDialog, handleEditorSelect, @@ -197,6 +218,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { setShowHelp, setDebugMessage, openThemeDialog, + openAuthDialog, openEditorDialog, performMemoryRefresh, toggleCorgiMode, @@ -306,6 +328,11 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { return editorType as EditorType; }, [settings, openEditorDialog]); + const onAuthError = useCallback(() => { + setAuthError('reauth required'); + openAuthDialog(); + }, [openAuthDialog, setAuthError]); + const { streamingState, submitQuery, @@ -322,6 +349,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { handleSlashCommand, shellModeActive, getPreferredEditor, + onAuthError, ); pendingHistoryItems.push(...pendingGeminiHistoryItems); const { elapsedTime, currentLoadingPhrase } = @@ -557,6 +585,20 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { terminalWidth={mainAreaWidth} /> </Box> + ) : isAuthDialogOpen ? ( + <Box flexDirection="column"> + {authError && ( + <Box marginBottom={1}> + <Text color={Colors.AccentRed}>{authError}</Text> + </Box> + )} + <AuthDialog + onSelect={handleAuthSelect} + onHighlight={handleAuthHighlight} + settings={settings} + initialErrorMessage={authError} + /> + </Box> ) : isEditorDialogOpen ? ( <Box flexDirection="column"> {editorError && ( diff --git a/packages/cli/src/ui/components/AuthDialog.test.tsx b/packages/cli/src/ui/components/AuthDialog.test.tsx new file mode 100644 index 00000000..a5f46d93 --- /dev/null +++ b/packages/cli/src/ui/components/AuthDialog.test.tsx @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from 'ink-testing-library'; +import { AuthDialog } from './AuthDialog.js'; +import { LoadedSettings } from '../../config/settings.js'; +import { AuthType } from '@gemini-cli/core'; + +describe('AuthDialog', () => { + it('should show an error if the initial auth type is invalid', () => { + const settings: LoadedSettings = new LoadedSettings( + { + settings: { + selectedAuthType: AuthType.USE_GEMINI, + }, + path: '', + }, + { + settings: {}, + path: '', + }, + [], + ); + + const { lastFrame } = render( + <AuthDialog + onSelect={() => {}} + onHighlight={() => {}} + settings={settings} + initialErrorMessage="GEMINI_API_KEY environment variable not found" + />, + ); + + expect(lastFrame()).toContain( + 'GEMINI_API_KEY environment variable not found', + ); + }); +}); diff --git a/packages/cli/src/ui/components/AuthDialog.tsx b/packages/cli/src/ui/components/AuthDialog.tsx new file mode 100644 index 00000000..b16529cf --- /dev/null +++ b/packages/cli/src/ui/components/AuthDialog.tsx @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useState } from 'react'; +import { Box, Text, useInput } from 'ink'; +import { Colors } from '../colors.js'; +import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; +import { LoadedSettings, SettingScope } from '../../config/settings.js'; +import { AuthType } from '@gemini-cli/core'; +import { validateAuthMethod } from '../../config/auth.js'; + +interface AuthDialogProps { + onSelect: (authMethod: string | undefined, scope: SettingScope) => void; + onHighlight: (authMethod: string | undefined) => void; + settings: LoadedSettings; + initialErrorMessage?: string | null; +} + +export function AuthDialog({ + onSelect, + onHighlight, + settings, + initialErrorMessage, +}: AuthDialogProps): React.JSX.Element { + const [errorMessage, setErrorMessage] = useState<string | null>( + initialErrorMessage || null, + ); + const authItems = [ + { + label: 'Login with Google Personal Account', + value: AuthType.LOGIN_WITH_GOOGLE_PERSONAL, + }, + { label: 'Gemini API Key', value: AuthType.USE_GEMINI }, + { + label: 'Login with GCP Project and Google Work Account', + value: AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE, + }, + { label: 'Vertex AI', value: AuthType.USE_VERTEX_AI }, + ]; + + let initialAuthIndex = authItems.findIndex( + (item) => item.value === settings.merged.selectedAuthType, + ); + + if (initialAuthIndex === -1) { + initialAuthIndex = 0; + } + + const handleAuthSelect = (authMethod: string) => { + const error = validateAuthMethod(authMethod); + if (error) { + setErrorMessage(error); + } else { + setErrorMessage(null); + onSelect(authMethod, SettingScope.User); + } + }; + + useInput((_input, key) => { + if (key.escape) { + onSelect(undefined, SettingScope.User); + } + }); + + return ( + <Box + borderStyle="round" + borderColor={Colors.Gray} + flexDirection="column" + padding={1} + width="100%" + > + <Text bold>Select Auth Method</Text> + <RadioButtonSelect + items={authItems} + initialIndex={initialAuthIndex} + onSelect={handleAuthSelect} + onHighlight={onHighlight} + isFocused={true} + /> + {errorMessage && ( + <Box marginTop={1}> + <Text color={Colors.AccentRed}>{errorMessage}</Text> + </Box> + )} + <Box marginTop={1}> + <Text color={Colors.Gray}>(Use Enter to select)</Text> + </Box> + </Box> + ); +} diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 7c750af1..04931c7f 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -103,6 +103,7 @@ describe('useSlashCommandProcessor', () => { let mockSetShowHelp: ReturnType<typeof vi.fn>; let mockOnDebugMessage: ReturnType<typeof vi.fn>; let mockOpenThemeDialog: ReturnType<typeof vi.fn>; + let mockOpenAuthDialog: ReturnType<typeof vi.fn>; let mockOpenEditorDialog: ReturnType<typeof vi.fn>; let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>; let mockSetQuittingMessages: ReturnType<typeof vi.fn>; @@ -120,6 +121,7 @@ describe('useSlashCommandProcessor', () => { mockSetShowHelp = vi.fn(); mockOnDebugMessage = vi.fn(); mockOpenThemeDialog = vi.fn(); + mockOpenAuthDialog = vi.fn(); mockOpenEditorDialog = vi.fn(); mockPerformMemoryRefresh = vi.fn().mockResolvedValue(undefined); mockSetQuittingMessages = vi.fn(); @@ -171,6 +173,7 @@ describe('useSlashCommandProcessor', () => { mockSetShowHelp, mockOnDebugMessage, mockOpenThemeDialog, + mockOpenAuthDialog, mockOpenEditorDialog, mockPerformMemoryRefresh, mockCorgiMode, diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 0e622f23..ee7b55cb 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -68,6 +68,7 @@ export const useSlashCommandProcessor = ( setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, onDebugMessage: (message: string) => void, openThemeDialog: () => void, + openAuthDialog: () => void, openEditorDialog: () => void, performMemoryRefresh: () => Promise<void>, toggleCorgiMode: () => void, @@ -198,6 +199,13 @@ export const useSlashCommandProcessor = ( }, }, { + name: 'auth', + description: 'change the auth method', + action: (_mainCommand, _subCommand, _args) => { + openAuthDialog(); + }, + }, + { name: 'editor', description: 'set external editor preference', action: (_mainCommand, _subCommand, _args) => { @@ -907,6 +915,7 @@ Add any other context about the problem here. setShowHelp, refreshStatic, openThemeDialog, + openAuthDialog, openEditorDialog, clearItems, performMemoryRefresh, diff --git a/packages/cli/src/ui/hooks/useAuthCommand.ts b/packages/cli/src/ui/hooks/useAuthCommand.ts new file mode 100644 index 00000000..a9b1cb1e --- /dev/null +++ b/packages/cli/src/ui/hooks/useAuthCommand.ts @@ -0,0 +1,57 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useCallback, useEffect } from 'react'; +import { LoadedSettings, SettingScope } from '../../config/settings.js'; +import { AuthType, Config, clearCachedCredentialFile } from '@gemini-cli/core'; + +async function performAuthFlow(authMethod: AuthType, config: Config) { + await config.refreshAuth(authMethod); + console.log(`Authenticated via "${authMethod}".`); +} + +export const useAuthCommand = ( + settings: LoadedSettings, + setAuthError: (error: string | null) => void, + config: Config, +) => { + const [isAuthDialogOpen, setIsAuthDialogOpen] = useState( + settings.merged.selectedAuthType === undefined, + ); + + useEffect(() => { + if (!isAuthDialogOpen) { + performAuthFlow(settings.merged.selectedAuthType as AuthType, config); + } + }, [isAuthDialogOpen, settings, config]); + + const openAuthDialog = useCallback(() => { + setIsAuthDialogOpen(true); + }, []); + + const handleAuthSelect = useCallback( + async (authMethod: string | undefined, scope: SettingScope) => { + if (authMethod) { + await clearCachedCredentialFile(); + settings.setValue(scope, 'selectedAuthType', authMethod); + } + setIsAuthDialogOpen(false); + setAuthError(null); + }, + [settings, setAuthError], + ); + + const handleAuthHighlight = useCallback((_authMethod: string | undefined) => { + // For now, we don't do anything on highlight. + }, []); + + return { + isAuthDialogOpen, + openAuthDialog, + handleAuthSelect, + handleAuthHighlight, + }; +}; diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 96dd6aef..36f420e4 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -359,6 +359,7 @@ describe('useGeminiStream', () => { props.handleSlashCommand, props.shellModeActive, () => 'vscode' as EditorType, + () => {}, ), { initialProps: { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 6d92af0d..4049c884 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -22,6 +22,7 @@ import { GitService, EditorType, ThoughtSummary, + isAuthError, } from '@gemini-cli/core'; import { type Part, type PartListUnion } from '@google/genai'; import { @@ -87,6 +88,7 @@ export const useGeminiStream = ( >, shellModeActive: boolean, getPreferredEditor: () => EditorType | undefined, + onAuthError: () => void, ) => { const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); @@ -496,7 +498,9 @@ export const useGeminiStream = ( setPendingHistoryItem(null); } } catch (error: unknown) { - if (!isNodeError(error) || error.name !== 'AbortError') { + if (isAuthError(error)) { + onAuthError(); + } else if (!isNodeError(error) || error.name !== 'AbortError') { addItem( { type: MessageType.ERROR, @@ -522,6 +526,7 @@ export const useGeminiStream = ( setInitError, geminiClient, startNewTurn, + onAuthError, ], ); diff --git a/packages/cli/src/utils/modelCheck.test.ts b/packages/cli/src/utils/modelCheck.test.ts deleted file mode 100644 index 11e38c01..00000000 --- a/packages/cli/src/utils/modelCheck.test.ts +++ /dev/null @@ -1,154 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { getEffectiveModel } from './modelCheck.js'; -import { - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_FLASH_MODEL, -} from '@gemini-cli/core'; - -// Mock global fetch -global.fetch = vi.fn(); - -// Mock AbortController -const mockAbort = vi.fn(); -global.AbortController = vi.fn(() => ({ - signal: { aborted: false }, // Start with not aborted - abort: mockAbort, - // eslint-disable-next-line @typescript-eslint/no-explicit-any -})) as any; - -describe('getEffectiveModel', () => { - const apiKey = 'test-api-key'; - - beforeEach(() => { - vi.useFakeTimers(); - vi.clearAllMocks(); - // Reset signal for each test if AbortController mock is more complex - global.AbortController = vi.fn(() => ({ - signal: { aborted: false }, - abort: mockAbort, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - })) as any; - }); - - afterEach(() => { - vi.restoreAllMocks(); - vi.useRealTimers(); - }); - - describe('when currentConfiguredModel is not DEFAULT_GEMINI_MODEL', () => { - it('should return the currentConfiguredModel without fetching', async () => { - const customModel = 'custom-model-name'; - const result = await getEffectiveModel(apiKey, customModel); - expect(result).toEqual(customModel); - expect(fetch).not.toHaveBeenCalled(); - }); - }); - - describe('when currentConfiguredModel is DEFAULT_GEMINI_MODEL', () => { - it('should switch to DEFAULT_GEMINI_FLASH_MODEL if fetch returns 429', async () => { - (fetch as vi.Mock).mockResolvedValueOnce({ - ok: false, - status: 429, - }); - const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); - expect(result).toEqual(DEFAULT_GEMINI_FLASH_MODEL); - expect(fetch).toHaveBeenCalledTimes(1); - expect(fetch).toHaveBeenCalledWith( - `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${apiKey}`, - expect.any(Object), - ); - }); - - it('should return DEFAULT_GEMINI_MODEL if fetch returns 200', async () => { - (fetch as vi.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - }); - const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); - expect(result).toEqual(DEFAULT_GEMINI_MODEL); - expect(fetch).toHaveBeenCalledTimes(1); - }); - - it('should return DEFAULT_GEMINI_MODEL if fetch returns a non-429 error status (e.g., 500)', async () => { - (fetch as vi.Mock).mockResolvedValueOnce({ - ok: false, - status: 500, - }); - const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); - expect(result).toEqual(DEFAULT_GEMINI_MODEL); - expect(fetch).toHaveBeenCalledTimes(1); - }); - - it('should return DEFAULT_GEMINI_MODEL if fetch throws a network error', async () => { - (fetch as vi.Mock).mockRejectedValueOnce(new Error('Network error')); - const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); - expect(result).toEqual(DEFAULT_GEMINI_MODEL); - expect(fetch).toHaveBeenCalledTimes(1); - }); - - it('should return DEFAULT_GEMINI_MODEL if fetch times out', async () => { - // Simulate AbortController's signal changing and fetch throwing AbortError - const abortControllerInstance = { - signal: { aborted: false }, // mutable signal - abort: vi.fn(() => { - abortControllerInstance.signal.aborted = true; // Use abortControllerInstance - }), - }; - (global.AbortController as vi.Mock).mockImplementationOnce( - () => abortControllerInstance, - ); - - (fetch as vi.Mock).mockImplementationOnce( - async ({ signal }: { signal: AbortSignal }) => { - // Simulate the timeout advancing and abort being called - vi.advanceTimersByTime(2000); - if (signal.aborted) { - throw new DOMException('Aborted', 'AbortError'); - } - // Should not reach here in a timeout scenario - return { ok: true, status: 200 }; - }, - ); - - const resultPromise = getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); - // Ensure timers are advanced to trigger the timeout within getEffectiveModel - await vi.advanceTimersToNextTimerAsync(); // Or advanceTimersByTime(2000) if more precise control is needed - - const result = await resultPromise; - - expect(mockAbort).toHaveBeenCalledTimes(0); // setTimeout calls controller.abort(), not our direct mockAbort - expect(abortControllerInstance.abort).toHaveBeenCalledTimes(1); - expect(result).toEqual(DEFAULT_GEMINI_MODEL); - expect(fetch).toHaveBeenCalledTimes(1); - }); - - it('should correctly pass API key and model in the fetch request', async () => { - (fetch as vi.Mock).mockResolvedValueOnce({ ok: true, status: 200 }); - const specificApiKey = 'specific-key-for-this-test'; - await getEffectiveModel(specificApiKey, DEFAULT_GEMINI_MODEL); - - expect(fetch).toHaveBeenCalledWith( - `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${specificApiKey}`, - expect.objectContaining({ - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - contents: [{ parts: [{ text: 'test' }] }], - generationConfig: { - maxOutputTokens: 1, - temperature: 0, - topK: 1, - thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, - }, - }), - }), - ); - }); - }); -}); diff --git a/packages/cli/src/utils/modelCheck.ts b/packages/cli/src/utils/modelCheck.ts deleted file mode 100644 index 7d7a3b7d..00000000 --- a/packages/cli/src/utils/modelCheck.ts +++ /dev/null @@ -1,68 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_FLASH_MODEL, -} from '@gemini-cli/core'; - -/** - * Checks if the default "pro" model is rate-limited and returns a fallback "flash" - * model if necessary. This function is designed to be silent. - * @param apiKey The API key to use for the check. - * @param currentConfiguredModel The model currently configured in settings. - * @returns An object indicating the model to use, whether a switch occurred, - * and the original model if a switch happened. - */ -export async function getEffectiveModel( - apiKey: string, - currentConfiguredModel: string, -): Promise<string> { - if (currentConfiguredModel !== DEFAULT_GEMINI_MODEL) { - // Only check if the user is trying to use the specific pro model we want to fallback from. - return currentConfiguredModel; - } - - const modelToTest = DEFAULT_GEMINI_MODEL; - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${modelToTest}:generateContent?key=${apiKey}`; - const body = JSON.stringify({ - contents: [{ parts: [{ text: 'test' }] }], - generationConfig: { - maxOutputTokens: 1, - temperature: 0, - topK: 1, - thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, - }, - }); - - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 2000); // 500ms timeout for the request - - try { - const response = await fetch(endpoint, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body, - signal: controller.signal, - }); - - clearTimeout(timeoutId); - - if (response.status === 429) { - console.log( - `[INFO] Your configured model (${modelToTest}) was temporarily unavailable. Switched to ${fallbackModel} for this session.`, - ); - return fallbackModel; - } - // For any other case (success, other error codes), we stick to the original model. - return currentConfiguredModel; - } catch (_error) { - clearTimeout(timeoutId); - // On timeout or any other fetch error, stick to the original model. - return currentConfiguredModel; - } -} |
