summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEddie Santos <[email protected]>2025-06-07 13:38:05 -0700
committerGitHub <[email protected]>2025-06-07 13:38:05 -0700
commit27fdd1b6e6e50882ee9a17c85c5f6b845d4973ce (patch)
tree1a75c9c755c8ce02e39e0da79342cbd79b359acd
parent51cd5ffd91b1a061ee1da2b048c22cd05ca3e836 (diff)
Add embedder (#818)
-rw-r--r--packages/cli/src/config/config.ts2
-rw-r--r--packages/cli/src/ui/App.test.tsx5
-rw-r--r--packages/core/src/config/config.test.ts2
-rw-r--r--packages/core/src/config/config.ts7
-rw-r--r--packages/core/src/core/client.test.ts168
-rw-r--r--packages/core/src/core/client.ts38
-rw-r--r--packages/core/src/core/contentGenerator.ts4
-rw-r--r--packages/core/src/tools/tool-registry.test.ts1
8 files changed, 206 insertions, 21 deletions
diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts
index 001d17d5..6ab1453f 100644
--- a/packages/cli/src/config/config.ts
+++ b/packages/cli/src/config/config.ts
@@ -33,6 +33,7 @@ const logger = {
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05';
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
+export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';
interface CliArgs {
model: string | undefined;
@@ -177,6 +178,7 @@ export async function loadCliConfig(
const configParams: ConfigParameters = {
apiKey: apiKeyForServer,
model: modelToUse,
+ embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
targetDir: process.cwd(),
debugMode,
diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx
index 98d82be8..f4ada985 100644
--- a/packages/cli/src/ui/App.test.tsx
+++ b/packages/cli/src/ui/App.test.tsx
@@ -38,6 +38,7 @@ interface MockServerConfig {
vertexai?: boolean;
showMemoryUsage?: boolean;
accessibility?: AccessibilitySettings;
+ embeddingModel: string;
getApiKey: Mock<() => string>;
getModel: Mock<() => string>;
@@ -92,6 +93,7 @@ vi.mock('@gemini-code/core', async (importOriginal) => {
vertexai: opts.vertexai,
showMemoryUsage: opts.showMemoryUsage ?? false,
accessibility: opts.accessibility ?? {},
+ embeddingModel: opts.embeddingModel || 'test-embedding-model',
getApiKey: vi.fn(() => opts.apiKey || 'test-key'),
getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'),
@@ -178,7 +180,8 @@ describe('App UI', () => {
const ServerConfigMocked = vi.mocked(ServerConfig, true);
mockConfig = new ServerConfigMocked({
apiKey: 'test-key',
- model: 'test-model-in-options',
+ model: 'test-model',
+ embeddingModel: 'test-embedding-model',
sandbox: false,
targetDir: '/test/dir',
debugMode: false,
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts
index 411b124d..3800585d 100644
--- a/packages/core/src/config/config.test.ts
+++ b/packages/core/src/config/config.test.ts
@@ -48,9 +48,11 @@ describe('Server Config (config.ts)', () => {
const USER_AGENT = 'ServerTestAgent/1.0';
const USER_MEMORY = 'Test User Memory';
const TELEMETRY = false;
+ const EMBEDDING_MODEL = 'gemini-embedding';
const baseParams: ConfigParameters = {
apiKey: API_KEY,
model: MODEL,
+ embeddingModel: EMBEDDING_MODEL,
sandbox: SANDBOX,
targetDir: TARGET_DIR,
debugMode: DEBUG_MODE,
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 00b3e35d..75db970b 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -55,6 +55,7 @@ export class MCPServerConfig {
export interface ConfigParameters {
apiKey: string;
model: string;
+ embeddingModel: string;
sandbox: boolean | string;
targetDir: string;
debugMode: boolean;
@@ -84,6 +85,7 @@ export class Config {
private toolRegistry: Promise<ToolRegistry>;
private readonly apiKey: string;
private readonly model: string;
+ private readonly embeddingModel: string;
private readonly sandbox: boolean | string;
private readonly targetDir: string;
private readonly debugMode: boolean;
@@ -113,6 +115,7 @@ export class Config {
constructor(params: ConfigParameters) {
this.apiKey = params.apiKey;
this.model = params.model;
+ this.embeddingModel = params.embeddingModel;
this.sandbox = params.sandbox;
this.targetDir = path.resolve(params.targetDir);
this.debugMode = params.debugMode;
@@ -163,6 +166,10 @@ export class Config {
return this.model;
}
+ getEmbeddingModel(): string {
+ return this.embeddingModel;
+ }
+
getSandbox(): boolean | string {
return this.sandbox;
}
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 228701d8..9c12423c 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -6,29 +6,23 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import { Chat, GenerateContentResponse } from '@google/genai';
+import {
+ Chat,
+ EmbedContentResponse,
+ GenerateContentResponse,
+ GoogleGenAI,
+} from '@google/genai';
+import { GeminiClient } from './client.js';
+import { Config } from '../config/config.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
+const mockEmbedContentFn = vi.fn();
-vi.mock('@google/genai', async (importOriginal) => {
- const actual = await importOriginal<typeof import('@google/genai')>();
- const MockedGoogleGenerativeAI = vi
- .fn()
- .mockImplementation((/*...args*/) => ({
- chats: { create: mockChatCreateFn },
- models: { generateContent: mockGenerateContentFn },
- }));
- return {
- ...actual,
- GoogleGenerativeAI: MockedGoogleGenerativeAI,
- Chat: vi.fn(),
- Type: actual.Type ?? { OBJECT: 'OBJECT', STRING: 'STRING' },
- };
-});
+vi.mock('@google/genai');
-vi.mock('../config/config');
+vi.mock('../config/config.js');
vi.mock('./prompts');
vi.mock('../utils/getFolderStructure', () => ({
getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'),
@@ -44,8 +38,24 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({
}));
describe('Gemini Client (client.ts)', () => {
+ let client: GeminiClient;
beforeEach(() => {
vi.resetAllMocks();
+
+ // Set up the mock for GoogleGenAI constructor and its methods
+ const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
+ MockedGoogleGenAI.mockImplementation(() => {
+ const mock = {
+ chats: { create: mockChatCreateFn },
+ models: {
+ generateContent: mockGenerateContentFn,
+ embedContent: mockEmbedContentFn,
+ },
+ };
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ return mock as any;
+ });
+
mockChatCreateFn.mockResolvedValue({} as Chat);
mockGenerateContentFn.mockResolvedValue({
candidates: [
@@ -56,6 +66,35 @@ describe('Gemini Client (client.ts)', () => {
},
],
} as unknown as GenerateContentResponse);
+
+ // Because the GeminiClient constructor kicks off an async process (startChat)
+ // that depends on a fully-formed Config object, we need to mock the
+ // entire implementation of Config for these tests.
+ const mockToolRegistry = {
+ getFunctionDeclarations: vi.fn().mockReturnValue([]),
+ getTool: vi.fn().mockReturnValue(null),
+ };
+ const MockedConfig = vi.mocked(Config, true);
+ MockedConfig.mockImplementation(() => {
+ const mock = {
+ getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
+ getModel: vi.fn().mockReturnValue('test-model'),
+ getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
+ getApiKey: vi.fn().mockReturnValue('test-key'),
+ getVertexAI: vi.fn().mockReturnValue(false),
+ getUserAgent: vi.fn().mockReturnValue('test-agent'),
+ getUserMemory: vi.fn().mockReturnValue(''),
+ getFullContext: vi.fn().mockReturnValue(false),
+ };
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ return mock as any;
+ });
+
+ // We can instantiate the client here since Config is mocked
+ // and the constructor will use the mocked GoogleGenAI
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ const mockConfig = new Config({} as any);
+ client = new GeminiClient(mockConfig);
});
afterEach(() => {
@@ -82,8 +121,97 @@ describe('Gemini Client (client.ts)', () => {
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
- // Add a placeholder test to keep the suite valid
- it('should have a placeholder test', () => {
- expect(true).toBe(true);
+ describe('generateEmbedding', () => {
+ const texts = ['hello world', 'goodbye world'];
+ const testEmbeddingModel = 'test-embedding-model';
+
+ it('should call embedContent with correct parameters and return embeddings', async () => {
+ const mockEmbeddings = [
+ [0.1, 0.2, 0.3],
+ [0.4, 0.5, 0.6],
+ ];
+ const mockResponse: EmbedContentResponse = {
+ embeddings: [
+ { values: mockEmbeddings[0] },
+ { values: mockEmbeddings[1] },
+ ],
+ };
+ mockEmbedContentFn.mockResolvedValue(mockResponse);
+
+ const result = await client.generateEmbedding(texts);
+
+ expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
+ expect(mockEmbedContentFn).toHaveBeenCalledWith({
+ model: testEmbeddingModel,
+ contents: texts,
+ });
+ expect(result).toEqual(mockEmbeddings);
+ });
+
+ it('should return an empty array if an empty array is passed', async () => {
+ const result = await client.generateEmbedding([]);
+ expect(result).toEqual([]);
+ expect(mockEmbedContentFn).not.toHaveBeenCalled();
+ });
+
+ it('should throw an error if API response has no embeddings array', async () => {
+ mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
+
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'No embeddings found in API response.',
+ );
+ });
+
+ it('should throw an error if API response has an empty embeddings array', async () => {
+ const mockResponse: EmbedContentResponse = {
+ embeddings: [],
+ };
+ mockEmbedContentFn.mockResolvedValue(mockResponse);
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'No embeddings found in API response.',
+ );
+ });
+
+ it('should throw an error if API returns a mismatched number of embeddings', async () => {
+ const mockResponse: EmbedContentResponse = {
+ embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
+ };
+ mockEmbedContentFn.mockResolvedValue(mockResponse);
+
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'API returned a mismatched number of embeddings. Expected 2, got 1.',
+ );
+ });
+
+ it('should throw an error if any embedding has nullish values', async () => {
+ const mockResponse: EmbedContentResponse = {
+ embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
+ };
+ mockEmbedContentFn.mockResolvedValue(mockResponse);
+
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'API returned an empty embedding for input text at index 1: "goodbye world"',
+ );
+ });
+
+ it('should throw an error if any embedding has an empty values array', async () => {
+ const mockResponse: EmbedContentResponse = {
+ embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
+ };
+ mockEmbedContentFn.mockResolvedValue(mockResponse);
+
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'API returned an empty embedding for input text at index 0: "hello world"',
+ );
+ });
+
+ it('should propagate errors from the API call', async () => {
+ const apiError = new Error('API Failure');
+ mockEmbedContentFn.mockRejectedValue(apiError);
+
+ await expect(client.generateEmbedding(texts)).rejects.toThrow(
+ 'API Failure',
+ );
+ });
});
});
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index d1a59eb1..c4515f93 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -5,6 +5,8 @@
*/
import {
+ EmbedContentResponse,
+ EmbedContentParameters,
GenerateContentConfig,
GoogleGenAI,
Part,
@@ -38,6 +40,7 @@ export class GeminiClient {
private chat: Promise<GeminiChat>;
private contentGenerator: ContentGenerator;
private model: string;
+ private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
@@ -60,6 +63,7 @@ export class GeminiClient {
});
this.contentGenerator = googleGenAI.models;
this.model = config.getModel();
+ this.embeddingModel = config.getEmbeddingModel();
this.chat = this.startChat();
}
@@ -450,6 +454,40 @@ export class GeminiClient {
}
}
+ async generateEmbedding(texts: string[]): Promise<number[][]> {
+ if (!texts || texts.length === 0) {
+ return [];
+ }
+ const embedModelParams: EmbedContentParameters = {
+ model: this.embeddingModel,
+ contents: texts,
+ };
+ const embedContentResponse: EmbedContentResponse =
+ await this.contentGenerator.embedContent(embedModelParams);
+ if (
+ !embedContentResponse.embeddings ||
+ embedContentResponse.embeddings.length === 0
+ ) {
+ throw new Error('No embeddings found in API response.');
+ }
+
+ if (embedContentResponse.embeddings.length !== texts.length) {
+ throw new Error(
+ `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
+ );
+ }
+
+ return embedContentResponse.embeddings.map((embedding, index) => {
+ const values = embedding.values;
+ if (!values || values.length === 0) {
+ throw new Error(
+ `API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
+ );
+ }
+ return values;
+ });
+ }
+
private async tryCompressChat(): Promise<boolean> {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history
diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts
index 32b48c5c..955cd152 100644
--- a/packages/core/src/core/contentGenerator.ts
+++ b/packages/core/src/core/contentGenerator.ts
@@ -9,6 +9,8 @@ import {
GenerateContentResponse,
GenerateContentParameters,
CountTokensParameters,
+ EmbedContentResponse,
+ EmbedContentParameters,
} from '@google/genai';
/**
@@ -24,4 +26,6 @@ export interface ContentGenerator {
): Promise<AsyncGenerator<GenerateContentResponse>>;
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
+
+ embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
}
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index f57f5bce..0c23a74e 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -126,6 +126,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> {
const baseConfigParams: ConfigParameters = {
apiKey: 'test-api-key',
model: 'test-model',
+ embeddingModel: 'test-embedding-model',
sandbox: false,
targetDir: '/test/dir',
debugMode: false,