diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/code_assist/setup.test.ts | 82 | ||||
| -rw-r--r-- | packages/core/src/code_assist/setup.ts | 2 | ||||
| -rw-r--r-- | packages/core/src/core/contentGenerator.test.ts | 79 | ||||
| -rw-r--r-- | packages/core/src/core/contentGenerator.ts | 17 |
4 files changed, 165 insertions, 15 deletions
diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts new file mode 100644 index 00000000..479abae0 --- /dev/null +++ b/packages/core/src/code_assist/setup.test.ts @@ -0,0 +1,82 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { setupUser, ProjectIdRequiredError } from './setup.js'; +import { CodeAssistServer } from '../code_assist/server.js'; +import { OAuth2Client } from 'google-auth-library'; +import { GeminiUserTier, UserTierId } from './types.js'; + +vi.mock('../code_assist/server.js'); + +const mockPaidTier: GeminiUserTier = { + id: UserTierId.STANDARD, + name: 'paid', + description: 'Paid tier', +}; + +describe('setupUser', () => { + let mockLoad: ReturnType<typeof vi.fn>; + let mockOnboardUser: ReturnType<typeof vi.fn>; + + beforeEach(() => { + vi.resetAllMocks(); + mockLoad = vi.fn(); + mockOnboardUser = vi.fn().mockResolvedValue({ + done: true, + response: { + cloudaicompanionProject: { + id: 'server-project', + }, + }, + }); + vi.mocked(CodeAssistServer).mockImplementation( + () => + ({ + loadCodeAssist: mockLoad, + onboardUser: mockOnboardUser, + }) as unknown as CodeAssistServer, + ); + }); + + it('should use GOOGLE_CLOUD_PROJECT when set', async () => { + process.env.GOOGLE_CLOUD_PROJECT = 'test-project'; + mockLoad.mockResolvedValue({ + currentTier: mockPaidTier, + }); + await setupUser({} as OAuth2Client); + expect(CodeAssistServer).toHaveBeenCalledWith( + expect.any(Object), + 'test-project', + ); + }); + + it('should treat empty GOOGLE_CLOUD_PROJECT as undefined and use project from server', async () => { + process.env.GOOGLE_CLOUD_PROJECT = ''; + mockLoad.mockResolvedValue({ + cloudaicompanionProject: 'server-project', + currentTier: mockPaidTier, + }); + const projectId = await setupUser({} as OAuth2Client); + expect(CodeAssistServer).toHaveBeenCalledWith( + expect.any(Object), + undefined, + ); + expect(projectId).toBe('server-project'); + }); + + it('should throw ProjectIdRequiredError when no project ID is available', async () => { + delete process.env.GOOGLE_CLOUD_PROJECT; + // And the server itself requires a project ID internally + vi.mocked(CodeAssistServer).mockImplementation(() => { + throw new ProjectIdRequiredError(); + }); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + ProjectIdRequiredError, + ); + }); +}); diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 7db6bdcd..3c7b81b0 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -28,7 +28,7 @@ export class ProjectIdRequiredError extends Error { * @returns the user's actual project id */ export async function setupUser(client: OAuth2Client): Promise<string> { - let projectId = process.env.GOOGLE_CLOUD_PROJECT; + let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined; const caServer = new CodeAssistServer(client, projectId); const clientMetadata: ClientMetadata = { diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 4c6134f2..eb480710 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -4,15 +4,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi } from 'vitest'; -import { createContentGenerator, AuthType } from './contentGenerator.js'; +import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; +import { + createContentGenerator, + AuthType, + createContentGeneratorConfig, +} from './contentGenerator.js'; import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js'; import { GoogleGenAI } from '@google/genai'; vi.mock('../code_assist/codeAssist.js'); vi.mock('@google/genai'); -describe('contentGenerator', () => { +describe('createContentGenerator', () => { it('should create a CodeAssistContentGenerator', async () => { const mockGenerator = {} as unknown; vi.mocked(createCodeAssistContentGenerator).mockResolvedValue( @@ -48,3 +52,72 @@ describe('contentGenerator', () => { expect(generator).toBe((mockGenerator as GoogleGenAI).models); }); }); + +describe('createContentGeneratorConfig', () => { + const originalEnv = process.env; + + beforeEach(() => { + // Reset modules to re-evaluate imports and environment variables + vi.resetModules(); + // Restore process.env before each test + process.env = { ...originalEnv }; + }); + + afterAll(() => { + // Restore original process.env after all tests + process.env = originalEnv; + }); + + it('should configure for Gemini using GEMINI_API_KEY when set', async () => { + process.env.GEMINI_API_KEY = 'env-gemini-key'; + const config = await createContentGeneratorConfig( + undefined, + AuthType.USE_GEMINI, + ); + expect(config.apiKey).toBe('env-gemini-key'); + expect(config.vertexai).toBe(false); + }); + + it('should not configure for Gemini if GEMINI_API_KEY is empty', async () => { + process.env.GEMINI_API_KEY = ''; + const config = await createContentGeneratorConfig( + undefined, + AuthType.USE_GEMINI, + ); + expect(config.apiKey).toBeUndefined(); + expect(config.vertexai).toBeUndefined(); + }); + + it('should configure for Vertex AI using GOOGLE_API_KEY when set', async () => { + process.env.GOOGLE_API_KEY = 'env-google-key'; + const config = await createContentGeneratorConfig( + undefined, + AuthType.USE_VERTEX_AI, + ); + expect(config.apiKey).toBe('env-google-key'); + expect(config.vertexai).toBe(true); + }); + + it('should configure for Vertex AI using GCP project and location when set', async () => { + process.env.GOOGLE_CLOUD_PROJECT = 'env-gcp-project'; + process.env.GOOGLE_CLOUD_LOCATION = 'env-gcp-location'; + const config = await createContentGeneratorConfig( + undefined, + AuthType.USE_VERTEX_AI, + ); + expect(config.vertexai).toBe(true); + expect(config.apiKey).toBeUndefined(); + }); + + it('should not configure for Vertex AI if required env vars are empty', async () => { + process.env.GOOGLE_API_KEY = ''; + process.env.GOOGLE_CLOUD_PROJECT = ''; + process.env.GOOGLE_CLOUD_LOCATION = ''; + const config = await createContentGeneratorConfig( + undefined, + AuthType.USE_VERTEX_AI, + ); + expect(config.apiKey).toBeUndefined(); + expect(config.vertexai).toBeUndefined(); + }); +}); diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index ce3c11a9..e9e1138f 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -52,10 +52,10 @@ export async function createContentGeneratorConfig( model: string | undefined, authType: AuthType | undefined, ): Promise<ContentGeneratorConfig> { - const geminiApiKey = process.env.GEMINI_API_KEY; - const googleApiKey = process.env.GOOGLE_API_KEY; - const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT; - const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION; + const geminiApiKey = process.env.GEMINI_API_KEY || undefined; + const googleApiKey = process.env.GOOGLE_API_KEY || undefined; + const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined; + const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined; // Use runtime model from config if available, otherwise fallback to parameter or default const effectiveModel = model || DEFAULT_GEMINI_MODEL; @@ -75,6 +75,7 @@ export async function createContentGeneratorConfig( if (authType === AuthType.USE_GEMINI && geminiApiKey) { contentGeneratorConfig.apiKey = geminiApiKey; + contentGeneratorConfig.vertexai = false; contentGeneratorConfig.model = await getEffectiveModel( contentGeneratorConfig.apiKey, contentGeneratorConfig.model, @@ -85,16 +86,10 @@ export async function createContentGeneratorConfig( if ( authType === AuthType.USE_VERTEX_AI && - !!googleApiKey && - googleCloudProject && - googleCloudLocation + (googleApiKey || (googleCloudProject && googleCloudLocation)) ) { contentGeneratorConfig.apiKey = googleApiKey; contentGeneratorConfig.vertexai = true; - contentGeneratorConfig.model = await getEffectiveModel( - contentGeneratorConfig.apiKey, - contentGeneratorConfig.model, - ); return contentGeneratorConfig; } |
