summaryrefslogtreecommitdiff
path: root/packages/core/src/config
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/config')
-rw-r--r--packages/core/src/config/config.test.ts47
-rw-r--r--packages/core/src/config/config.ts31
2 files changed, 66 insertions, 12 deletions
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts
index 10fd814f..8a9f038c 100644
--- a/packages/core/src/config/config.test.ts
+++ b/packages/core/src/config/config.test.ts
@@ -42,6 +42,21 @@ vi.mock('../tools/memoryTool', () => ({
GEMINI_CONFIG_DIR: '.gemini',
}));
+vi.mock('../core/contentGenerator.js', async (importOriginal) => {
+ const actual =
+ await importOriginal<typeof import('../core/contentGenerator.js')>();
+ return {
+ ...actual,
+ createContentGeneratorConfig: vi.fn(),
+ };
+});
+
+vi.mock('../core/client.js', () => ({
+ GeminiClient: vi.fn().mockImplementation(() => ({
+ // Mock any methods on GeminiClient that might be used.
+ })),
+}));
+
vi.mock('../telemetry/index.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('../telemetry/index.js')>();
return {
@@ -51,7 +66,6 @@ vi.mock('../telemetry/index.js', async (importOriginal) => {
});
describe('Server Config (config.ts)', () => {
- const API_KEY = 'server-api-key';
const MODEL = 'gemini-pro';
const SANDBOX: SandboxConfig = {
command: 'docker',
@@ -67,10 +81,6 @@ describe('Server Config (config.ts)', () => {
const SESSION_ID = 'test-session-id';
const baseParams: ConfigParameters = {
cwd: '/tmp',
- contentGeneratorConfig: {
- apiKey: API_KEY,
- model: MODEL,
- },
embeddingModel: EMBEDDING_MODEL,
sandbox: SANDBOX,
targetDir: TARGET_DIR,
@@ -80,6 +90,7 @@ describe('Server Config (config.ts)', () => {
userMemory: USER_MEMORY,
telemetry: TELEMETRY_SETTINGS,
sessionId: SESSION_ID,
+ model: MODEL,
};
beforeEach(() => {
@@ -87,6 +98,32 @@ describe('Server Config (config.ts)', () => {
vi.clearAllMocks();
});
+ // i can't get vi mocking to import in core. only in cli. can't fix it now.
+ // describe('refreshAuth', () => {
+ // it('should refresh auth and update config', async () => {
+ // const config = new Config(baseParams);
+ // const newModel = 'gemini-ultra';
+ // const authType = AuthType.USE_GEMINI;
+ // const mockContentConfig = {
+ // model: newModel,
+ // apiKey: 'test-key',
+ // };
+
+ // (createContentGeneratorConfig as vi.Mock).mockResolvedValue(
+ // mockContentConfig,
+ // );
+
+ // await config.refreshAuth(authType);
+
+ // expect(createContentGeneratorConfig).toHaveBeenCalledWith(
+ // newModel,
+ // authType,
+ // );
+ // expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);
+ // expect(GeminiClient).toHaveBeenCalledWith(config);
+ // });
+ // });
+
it('Config constructor should store userMemory correctly', () => {
const config = new Config(baseParams);
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 6caf39e8..a97f5536 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -6,7 +6,11 @@
import * as path from 'node:path';
import process from 'node:process';
-import { ContentGeneratorConfig } from '../core/contentGenerator.js';
+import {
+ AuthType,
+ ContentGeneratorConfig,
+ createContentGeneratorConfig,
+} from '../core/contentGenerator.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { LSTool } from '../tools/ls.js';
import { ReadFileTool } from '../tools/read-file.js';
@@ -80,7 +84,6 @@ export interface SandboxConfig {
export interface ConfigParameters {
sessionId: string;
- contentGeneratorConfig: ContentGeneratorConfig;
embeddingModel?: string;
sandbox?: SandboxConfig;
targetDir: string;
@@ -106,12 +109,13 @@ export interface ConfigParameters {
cwd: string;
fileDiscoveryService?: FileDiscoveryService;
bugCommand?: BugCommandSettings;
+ model: string;
}
export class Config {
private toolRegistry: Promise<ToolRegistry>;
private readonly sessionId: string;
- private readonly contentGeneratorConfig: ContentGeneratorConfig;
+ private contentGeneratorConfig!: ContentGeneratorConfig;
private readonly embeddingModel: string;
private readonly sandbox: SandboxConfig | undefined;
private readonly targetDir: string;
@@ -130,7 +134,7 @@ export class Config {
private readonly showMemoryUsage: boolean;
private readonly accessibility: AccessibilitySettings;
private readonly telemetrySettings: TelemetrySettings;
- private readonly geminiClient: GeminiClient;
+ private geminiClient!: GeminiClient;
private readonly fileFilteringRespectGitIgnore: boolean;
private fileDiscoveryService: FileDiscoveryService | null = null;
private gitService: GitService | undefined = undefined;
@@ -138,10 +142,10 @@ export class Config {
private readonly proxy: string | undefined;
private readonly cwd: string;
private readonly bugCommand: BugCommandSettings | undefined;
+ private readonly model: string;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
- this.contentGeneratorConfig = params.contentGeneratorConfig;
this.embeddingModel =
params.embeddingModel ?? DEFAULT_GEMINI_EMBEDDING_MODEL;
this.sandbox = params.sandbox;
@@ -174,12 +178,12 @@ export class Config {
this.cwd = params.cwd ?? process.cwd();
this.fileDiscoveryService = params.fileDiscoveryService ?? null;
this.bugCommand = params.bugCommand;
+ this.model = params.model;
if (params.contextFileName) {
setGeminiMdFilename(params.contextFileName);
}
- this.geminiClient = new GeminiClient(this);
this.toolRegistry = createToolRegistry(this);
if (this.telemetrySettings.enabled) {
@@ -187,6 +191,19 @@ export class Config {
}
}
+ async refreshAuth(authMethod: AuthType) {
+ const contentConfig = await createContentGeneratorConfig(
+ this.getModel(),
+ authMethod,
+ );
+
+ const gc = new GeminiClient(this);
+ await gc.initialize(contentConfig);
+
+ this.contentGeneratorConfig = contentConfig;
+ this.geminiClient = gc;
+ }
+
getSessionId(): string {
return this.sessionId;
}
@@ -196,7 +213,7 @@ export class Config {
}
getModel(): string {
- return this.contentGeneratorConfig.model;
+ return this.contentGeneratorConfig?.model || this.model;
}
getEmbeddingModel(): string {