diff options
Diffstat (limited to 'packages/core/src/config')
| -rw-r--r-- | packages/core/src/config/config.test.ts | 47 | ||||
| -rw-r--r-- | packages/core/src/config/config.ts | 31 |
2 files changed, 66 insertions, 12 deletions
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 10fd814f..8a9f038c 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -42,6 +42,21 @@ vi.mock('../tools/memoryTool', () => ({ GEMINI_CONFIG_DIR: '.gemini', })); +vi.mock('../core/contentGenerator.js', async (importOriginal) => { + const actual = + await importOriginal<typeof import('../core/contentGenerator.js')>(); + return { + ...actual, + createContentGeneratorConfig: vi.fn(), + }; +}); + +vi.mock('../core/client.js', () => ({ + GeminiClient: vi.fn().mockImplementation(() => ({ + // Mock any methods on GeminiClient that might be used. + })), +})); + vi.mock('../telemetry/index.js', async (importOriginal) => { const actual = await importOriginal<typeof import('../telemetry/index.js')>(); return { @@ -51,7 +66,6 @@ vi.mock('../telemetry/index.js', async (importOriginal) => { }); describe('Server Config (config.ts)', () => { - const API_KEY = 'server-api-key'; const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { command: 'docker', @@ -67,10 +81,6 @@ describe('Server Config (config.ts)', () => { const SESSION_ID = 'test-session-id'; const baseParams: ConfigParameters = { cwd: '/tmp', - contentGeneratorConfig: { - apiKey: API_KEY, - model: MODEL, - }, embeddingModel: EMBEDDING_MODEL, sandbox: SANDBOX, targetDir: TARGET_DIR, @@ -80,6 +90,7 @@ describe('Server Config (config.ts)', () => { userMemory: USER_MEMORY, telemetry: TELEMETRY_SETTINGS, sessionId: SESSION_ID, + model: MODEL, }; beforeEach(() => { @@ -87,6 +98,32 @@ describe('Server Config (config.ts)', () => { vi.clearAllMocks(); }); + // i can't get vi mocking to import in core. only in cli. can't fix it now. + // describe('refreshAuth', () => { + // it('should refresh auth and update config', async () => { + // const config = new Config(baseParams); + // const newModel = 'gemini-ultra'; + // const authType = AuthType.USE_GEMINI; + // const mockContentConfig = { + // model: newModel, + // apiKey: 'test-key', + // }; + + // (createContentGeneratorConfig as vi.Mock).mockResolvedValue( + // mockContentConfig, + // ); + + // await config.refreshAuth(authType); + + // expect(createContentGeneratorConfig).toHaveBeenCalledWith( + // newModel, + // authType, + // ); + // expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); + // expect(GeminiClient).toHaveBeenCalledWith(config); + // }); + // }); + it('Config constructor should store userMemory correctly', () => { const config = new Config(baseParams); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 6caf39e8..a97f5536 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -6,7 +6,11 @@ import * as path from 'node:path'; import process from 'node:process'; -import { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import { + AuthType, + ContentGeneratorConfig, + createContentGeneratorConfig, +} from '../core/contentGenerator.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; import { ReadFileTool } from '../tools/read-file.js'; @@ -80,7 +84,6 @@ export interface SandboxConfig { export interface ConfigParameters { sessionId: string; - contentGeneratorConfig: ContentGeneratorConfig; embeddingModel?: string; sandbox?: SandboxConfig; targetDir: string; @@ -106,12 +109,13 @@ export interface ConfigParameters { cwd: string; fileDiscoveryService?: FileDiscoveryService; bugCommand?: BugCommandSettings; + model: string; } export class Config { private toolRegistry: Promise<ToolRegistry>; private readonly sessionId: string; - private readonly contentGeneratorConfig: ContentGeneratorConfig; + private contentGeneratorConfig!: ContentGeneratorConfig; private readonly embeddingModel: string; private readonly sandbox: SandboxConfig | undefined; private readonly targetDir: string; @@ -130,7 +134,7 @@ export class Config { private readonly showMemoryUsage: boolean; private readonly accessibility: AccessibilitySettings; private readonly telemetrySettings: TelemetrySettings; - private readonly geminiClient: GeminiClient; + private geminiClient!: GeminiClient; private readonly fileFilteringRespectGitIgnore: boolean; private fileDiscoveryService: FileDiscoveryService | null = null; private gitService: GitService | undefined = undefined; @@ -138,10 +142,10 @@ export class Config { private readonly proxy: string | undefined; private readonly cwd: string; private readonly bugCommand: BugCommandSettings | undefined; + private readonly model: string; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; - this.contentGeneratorConfig = params.contentGeneratorConfig; this.embeddingModel = params.embeddingModel ?? DEFAULT_GEMINI_EMBEDDING_MODEL; this.sandbox = params.sandbox; @@ -174,12 +178,12 @@ export class Config { this.cwd = params.cwd ?? process.cwd(); this.fileDiscoveryService = params.fileDiscoveryService ?? null; this.bugCommand = params.bugCommand; + this.model = params.model; if (params.contextFileName) { setGeminiMdFilename(params.contextFileName); } - this.geminiClient = new GeminiClient(this); this.toolRegistry = createToolRegistry(this); if (this.telemetrySettings.enabled) { @@ -187,6 +191,19 @@ export class Config { } } + async refreshAuth(authMethod: AuthType) { + const contentConfig = await createContentGeneratorConfig( + this.getModel(), + authMethod, + ); + + const gc = new GeminiClient(this); + await gc.initialize(contentConfig); + + this.contentGeneratorConfig = contentConfig; + this.geminiClient = gc; + } + getSessionId(): string { return this.sessionId; } @@ -196,7 +213,7 @@ export class Config { } getModel(): string { - return this.contentGeneratorConfig.model; + return this.contentGeneratorConfig?.model || this.model; } getEmbeddingModel(): string { |
