diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/config/config.test.ts | 1 | ||||
| -rw-r--r-- | packages/core/src/config/config.ts | 56 | ||||
| -rw-r--r-- | packages/core/src/core/client.test.ts | 70 | ||||
| -rw-r--r-- | packages/core/src/core/client.ts | 17 | ||||
| -rw-r--r-- | packages/core/src/core/geminiChat.test.ts | 69 | ||||
| -rw-r--r-- | packages/core/src/core/geminiRequest.test.ts | 85 | ||||
| -rw-r--r-- | packages/core/src/tools/tool-registry.test.ts | 1 |
7 files changed, 246 insertions, 53 deletions
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 2827f581..71af832b 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -51,6 +51,7 @@ describe('Server Config (config.ts)', () => { const EMBEDDING_MODEL = 'gemini-embedding'; const SESSION_ID = 'test-session-id'; const baseParams: ConfigParameters = { + cwd: '/tmp', contentGeneratorConfig: { apiKey: API_KEY, model: MODEL, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index b94a88a4..3bb9815f 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -4,11 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as dotenv from 'dotenv'; -import * as fs from 'node:fs'; import * as path from 'node:path'; import process from 'node:process'; -import * as os from 'node:os'; import { ContentGeneratorConfig } from '../core/contentGenerator.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; @@ -79,9 +76,12 @@ export interface ConfigParameters { accessibility?: AccessibilitySettings; telemetry?: boolean; telemetryLogUserPromptsEnabled?: boolean; + telemetryOtlpEndpoint?: string; fileFilteringRespectGitIgnore?: boolean; fileFilteringAllowBuildArtifacts?: boolean; checkpoint?: boolean; + proxy?: string; + cwd: string; } export class Config { @@ -115,6 +115,8 @@ export class Config { private fileDiscoveryService: FileDiscoveryService | null = null; private gitService: GitService | undefined = undefined; private readonly checkpoint: boolean; + private readonly proxy: string | undefined; + private readonly cwd: string; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -140,12 +142,14 @@ export class Config { this.telemetryLogUserPromptsEnabled = params.telemetryLogUserPromptsEnabled ?? true; this.telemetryOtlpEndpoint = - process.env.OTEL_EXPORTER_OTLP_ENDPOINT ?? 'http://localhost:4317'; + params.telemetryOtlpEndpoint ?? 'http://localhost:4317'; this.fileFilteringRespectGitIgnore = params.fileFilteringRespectGitIgnore ?? true; this.fileFilteringAllowBuildArtifacts = params.fileFilteringAllowBuildArtifacts ?? false; this.checkpoint = params.checkpoint ?? false; + this.proxy = params.proxy; + this.cwd = params.cwd ?? process.cwd(); if (params.contextFileName) { setGeminiMdFilename(params.contextFileName); @@ -297,6 +301,14 @@ export class Config { return this.checkpoint; } + getProxy(): string | undefined { + return this.proxy; + } + + getWorkingDir(): string { + return this.cwd; + } + async getFileService(): Promise<FileDiscoveryService> { if (!this.fileDiscoveryService) { this.fileDiscoveryService = new FileDiscoveryService(this.targetDir); @@ -317,42 +329,6 @@ export class Config { } } -function findEnvFile(startDir: string): string | null { - let currentDir = path.resolve(startDir); - while (true) { - // prefer gemini-specific .env under GEMINI_DIR - const geminiEnvPath = path.join(currentDir, GEMINI_DIR, '.env'); - if (fs.existsSync(geminiEnvPath)) { - return geminiEnvPath; - } - const envPath = path.join(currentDir, '.env'); - if (fs.existsSync(envPath)) { - return envPath; - } - const parentDir = path.dirname(currentDir); - if (parentDir === currentDir || !parentDir) { - // check .env under home as fallback, again preferring gemini-specific .env - const homeGeminiEnvPath = path.join(os.homedir(), GEMINI_DIR, '.env'); - if (fs.existsSync(homeGeminiEnvPath)) { - return homeGeminiEnvPath; - } - const homeEnvPath = path.join(os.homedir(), '.env'); - if (fs.existsSync(homeEnvPath)) { - return homeEnvPath; - } - return null; - } - currentDir = parentDir; - } -} - -export function loadEnvironment(): void { - const envFilePath = findEnvFile(process.cwd()); - if (envFilePath) { - dotenv.config({ path: envFilePath }); - } -} - export function createToolRegistry(config: Config): Promise<ToolRegistry> { const registry = new ToolRegistry(config); const targetDir = config.getTargetDir(); diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 0362f72a..d227e57b 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -17,6 +17,7 @@ import { ContentGenerator } from './contentGenerator.js'; import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; import { Turn } from './turn.js'; +import { getCoreSystemPrompt } from './prompts.js'; // --- Mocks --- const mockChatCreateFn = vi.fn(); @@ -54,6 +55,11 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({ result.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || undefined, })); +vi.mock('../telemetry/index.js', () => ({ + logApiRequest: vi.fn(), + logApiResponse: vi.fn(), + logApiError: vi.fn(), +})); describe('Gemini Client (client.ts)', () => { let client: GeminiClient; @@ -109,6 +115,8 @@ describe('Gemini Client (client.ts)', () => { getUserMemory: vi.fn().mockReturnValue(''), getFullContext: vi.fn().mockReturnValue(false), getSessionId: vi.fn().mockReturnValue('test-session-id'), + getProxy: vi.fn().mockReturnValue(undefined), + getWorkingDir: vi.fn().mockReturnValue('/test/dir'), }; // eslint-disable-next-line @typescript-eslint/no-explicit-any return mock as any; @@ -239,6 +247,68 @@ describe('Gemini Client (client.ts)', () => { }); }); + describe('generateContent', () => { + it('should call generateContent with the correct parameters', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const generationConfig = { temperature: 0.5 }; + const abortSignal = new AbortController().signal; + + // Mock countTokens + const mockGenerator: Partial<ContentGenerator> = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = Promise.resolve( + mockGenerator as ContentGenerator, + ); + + await client.generateContent(contents, generationConfig, abortSignal); + + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: 'test-model', + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.5, + topP: 1, + }, + contents, + }); + }); + }); + + describe('generateJson', () => { + it('should call generateContent with the correct parameters', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const schema = { type: 'string' }; + const abortSignal = new AbortController().signal; + + // Mock countTokens + const mockGenerator: Partial<ContentGenerator> = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = Promise.resolve( + mockGenerator as ContentGenerator, + ); + + await client.generateJson(contents, schema, abortSignal); + + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: 'gemini-2.0-flash', + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0, + topP: 1, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, + }); + }); + }); + describe('addHistory', () => { it('should call chat.addHistory with the provided content', async () => { const mockChat = { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 94cdf0e5..83c322b0 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -14,7 +14,6 @@ import { Tool, GenerateContentResponse, } from '@google/genai'; -import process from 'node:process'; import { getFolderStructure } from '../utils/getFolderStructure.js'; import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js'; import { Config } from '../config/config.js'; @@ -33,16 +32,6 @@ import { } from './contentGenerator.js'; import { ProxyAgent, setGlobalDispatcher } from 'undici'; -const proxy = - process.env.HTTPS_PROXY || - process.env.https_proxy || - process.env.HTTP_PROXY || - process.env.http_proxy; - -if (proxy) { - setGlobalDispatcher(new ProxyAgent(proxy)); -} - export class GeminiClient { private chat: Promise<GeminiChat>; private contentGenerator: Promise<ContentGenerator>; @@ -55,6 +44,10 @@ export class GeminiClient { private readonly MAX_TURNS = 100; constructor(private config: Config) { + if (config.getProxy()) { + setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); + } + this.contentGenerator = createContentGenerator( this.config.getContentGeneratorConfig(), ); @@ -83,7 +76,7 @@ export class GeminiClient { } private async getEnvironment(): Promise<Part[]> { - const cwd = process.cwd(); + const cwd = this.config.getWorkingDir(); const today = new Date().toLocaleDateString(undefined, { weekday: 'long', year: 'numeric', diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 03e933cc..24a7279d 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -5,7 +5,13 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { Content, Models, GenerateContentConfig, Part } from '@google/genai'; +import { + Content, + Models, + GenerateContentConfig, + Part, + GenerateContentResponse, +} from '@google/genai'; import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; @@ -20,6 +26,7 @@ const mockModelsModule = { const mockConfig = { getSessionId: () => 'test-session-id', + getTelemetryLogUserPromptsEnabled: () => true, } as unknown as Config; describe('GeminiChat', () => { @@ -37,6 +44,66 @@ describe('GeminiChat', () => { vi.restoreAllMocks(); }); + describe('sendMessage', () => { + it('should call generateContent with the correct parameters', async () => { + const response = { + candidates: [ + { + content: { + parts: [{ text: 'response' }], + role: 'model', + }, + finishReason: 'STOP', + index: 0, + safetyRatings: [], + }, + ], + text: () => 'response', + } as unknown as GenerateContentResponse; + vi.mocked(mockModelsModule.generateContent).mockResolvedValue(response); + + await chat.sendMessage({ message: 'hello' }); + + expect(mockModelsModule.generateContent).toHaveBeenCalledWith({ + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }); + }); + }); + + describe('sendMessageStream', () => { + it('should call generateContentStream with the correct parameters', async () => { + const response = (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'response' }], + role: 'model', + }, + finishReason: 'STOP', + index: 0, + safetyRatings: [], + }, + ], + text: () => 'response', + } as unknown as GenerateContentResponse; + })(); + vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + response, + ); + + await chat.sendMessageStream({ message: 'hello' }); + + expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({ + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }); + }); + }); + describe('recordHistory', () => { const userInput: Content = { role: 'user', diff --git a/packages/core/src/core/geminiRequest.test.ts b/packages/core/src/core/geminiRequest.test.ts new file mode 100644 index 00000000..fd298cb6 --- /dev/null +++ b/packages/core/src/core/geminiRequest.test.ts @@ -0,0 +1,85 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { partListUnionToString } from './geminiRequest.js'; +import { type Part } from '@google/genai'; + +describe('partListUnionToString', () => { + it('should return the string value if the input is a string', () => { + const result = partListUnionToString('hello'); + expect(result).toBe('hello'); + }); + + it('should return a concatenated string if the input is an array of strings', () => { + const result = partListUnionToString(['hello', ' ', 'world']); + expect(result).toBe('hello world'); + }); + + it('should handle videoMetadata', () => { + const part: Part = { videoMetadata: {} }; + const result = partListUnionToString(part); + expect(result).toBe('[Video Metadata]'); + }); + + it('should handle thought', () => { + const part: Part = { thought: true }; + const result = partListUnionToString(part); + expect(result).toBe('[Thought: true]'); + }); + + it('should handle codeExecutionResult', () => { + const part: Part = { codeExecutionResult: {} }; + const result = partListUnionToString(part); + expect(result).toBe('[Code Execution Result]'); + }); + + it('should handle executableCode', () => { + const part: Part = { executableCode: {} }; + const result = partListUnionToString(part); + expect(result).toBe('[Executable Code]'); + }); + + it('should handle fileData', () => { + const part: Part = { + fileData: { mimeType: 'text/plain', fileUri: 'file.txt' }, + }; + const result = partListUnionToString(part); + expect(result).toBe('[File Data]'); + }); + + it('should handle functionCall', () => { + const part: Part = { functionCall: { name: 'myFunction' } }; + const result = partListUnionToString(part); + expect(result).toBe('[Function Call: myFunction]'); + }); + + it('should handle functionResponse', () => { + const part: Part = { + functionResponse: { name: 'myFunction', response: {} }, + }; + const result = partListUnionToString(part); + expect(result).toBe('[Function Response: myFunction]'); + }); + + it('should handle inlineData', () => { + const part: Part = { inlineData: { mimeType: 'image/png', data: '...' } }; + const result = partListUnionToString(part); + expect(result).toBe('<image/png>'); + }); + + it('should handle text', () => { + const part: Part = { text: 'hello' }; + const result = partListUnionToString(part); + expect(result).toBe('hello'); + }); + + it('should return an empty string for an unknown part type', () => { + const part: Part = {}; + const result = partListUnionToString(part); + expect(result).toBe(''); + }); +}); diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index f4cedcd4..5837fd76 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -124,6 +124,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> { } const baseConfigParams: ConfigParameters = { + cwd: '/tmp', contentGeneratorConfig: { model: 'test-model', apiKey: 'test-api-key', |
