summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorScott Densmore <[email protected]>2025-06-02 14:55:51 -0700
committerGitHub <[email protected]>2025-06-02 14:55:51 -0700
commite428707e074627b0eacce2c2295f34b2ffa28198 (patch)
treeeb977c256c48ea40c661f459e89507f3beb22fba /packages/core/src
parent1dcf0a4cbdee249fd9a20c67b9b718563353773b (diff)
Refactor: Centralize GeminiClient in Config (#693)
Co-authored-by: N. Taylor Mullen <[email protected]>
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/config/config.ts7
-rw-r--r--packages/core/src/core/client.ts53
-rw-r--r--packages/core/src/tools/web-fetch.ts35
-rw-r--r--packages/core/src/tools/web-search.ts34
4 files changed, 79 insertions, 50 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index dc77208c..46e5123c 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -21,6 +21,7 @@ import { WebFetchTool } from '../tools/web-fetch.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
import { WebSearchTool } from '../tools/web-search.js';
+import { GeminiClient } from '../core/client.js';
import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js';
export enum ApprovalMode {
@@ -86,6 +87,7 @@ export class Config {
private approvalMode: ApprovalMode;
private readonly vertexai: boolean | undefined;
private readonly showMemoryUsage: boolean;
+ private readonly geminiClient: GeminiClient;
constructor(params: ConfigParameters) {
this.apiKey = params.apiKey;
@@ -112,6 +114,7 @@ export class Config {
}
this.toolRegistry = createToolRegistry(this);
+ this.geminiClient = new GeminiClient(this);
}
getApiKey(): string {
@@ -200,6 +203,10 @@ export class Config {
getShowMemoryUsage(): boolean {
return this.showMemoryUsage;
}
+
+ getGeminiClient(): GeminiClient {
+ return this.geminiClient;
+ }
}
function findEnvFile(startDir: string): string | null {
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index db30ac16..732126cb 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -12,6 +12,7 @@ import {
PartListUnion,
Content,
Tool,
+ GenerateContentResponse,
} from '@google/genai';
import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js';
@@ -262,4 +263,56 @@ export class GeminiClient {
throw new Error(`Failed to generate JSON content: ${message}`);
}
}
+
+ async generateContent(
+ contents: Content[],
+ generationConfig: GenerateContentConfig,
+ abortSignal: AbortSignal,
+ ): Promise<GenerateContentResponse> {
+ const modelToUse = this.model;
+ const configToUse: GenerateContentConfig = {
+ ...this.generateContentConfig,
+ ...generationConfig,
+ };
+
+ try {
+ const userMemory = this.config.getUserMemory();
+ const systemInstruction = getCoreSystemPrompt(userMemory);
+
+ const requestConfig = {
+ abortSignal,
+ ...configToUse,
+ systemInstruction,
+ };
+
+ const apiCall = () =>
+ this.client.models.generateContent({
+ model: modelToUse,
+ config: requestConfig,
+ contents,
+ });
+
+ const result = await retryWithBackoff(apiCall);
+ return result;
+ } catch (error) {
+ if (abortSignal.aborted) {
+ throw error;
+ }
+
+ await reportError(
+ error,
+ `Error generating content via API with model ${modelToUse}.`,
+ {
+ requestContents: contents,
+ requestConfig: configToUse,
+ },
+ 'generateContent-api',
+ );
+ const message =
+ error instanceof Error ? error.message : 'Unknown API error.';
+ throw new Error(
+ `Failed to generate content with model ${modelToUse}: ${message}`,
+ );
+ }
+ }
}
diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts
index 24617902..6a6048fc 100644
--- a/packages/core/src/tools/web-fetch.ts
+++ b/packages/core/src/tools/web-fetch.ts
@@ -4,13 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { GoogleGenAI, GroundingMetadata } from '@google/genai';
+import { GroundingMetadata } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, ToolResult } from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
-import { retryWithBackoff } from '../utils/retry.js';
// Interfaces for grounding metadata (similar to web-search.ts)
interface GroundingChunkWeb {
@@ -49,9 +48,6 @@ export interface WebFetchToolParams {
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
static readonly Name: string = 'web_fetch';
- private ai: GoogleGenAI;
- private modelName: string;
-
constructor(private readonly config: Config) {
super(
WebFetchTool.Name,
@@ -69,12 +65,6 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
type: 'object',
},
);
-
- const apiKeyFromConfig = this.config.getApiKey();
- this.ai = new GoogleGenAI({
- apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
- });
- this.modelName = this.config.getModel();
}
validateParams(params: WebFetchToolParams): string | null {
@@ -109,7 +99,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
async execute(
params: WebFetchToolParams,
- _signal: AbortSignal,
+ signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
@@ -120,23 +110,14 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
}
const userPrompt = params.prompt;
+ const geminiClient = this.config.getGeminiClient();
try {
- const apiCall = () =>
- this.ai.models.generateContent({
- model: this.modelName,
- contents: [
- {
- role: 'user',
- parts: [{ text: userPrompt }],
- },
- ],
- config: {
- tools: [{ urlContext: {} }],
- },
- });
-
- const response = await retryWithBackoff(apiCall);
+ const response = await geminiClient.generateContent(
+ [{ role: 'user', parts: [{ text: userPrompt }] }],
+ { tools: [{ urlContext: {} }] },
+ signal, // Pass signal
+ );
console.debug(
`[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`,
diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts
index ed2f341f..c4dcc54a 100644
--- a/packages/core/src/tools/web-search.ts
+++ b/packages/core/src/tools/web-search.ts
@@ -4,14 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { GoogleGenAI, GroundingMetadata } from '@google/genai';
+import { GroundingMetadata } from '@google/genai';
import { BaseTool, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
-import { retryWithBackoff } from '../utils/retry.js';
interface GroundingChunkWeb {
uri?: string;
@@ -64,9 +63,6 @@ export class WebSearchTool extends BaseTool<
> {
static readonly Name: string = 'google_web_search';
- private ai: GoogleGenAI;
- private modelName: string;
-
constructor(private readonly config: Config) {
super(
WebSearchTool.Name,
@@ -83,13 +79,6 @@ export class WebSearchTool extends BaseTool<
required: ['query'],
},
);
-
- const apiKeyFromConfig = this.config.getApiKey();
- // Initialize GoogleGenAI, allowing fallback to environment variables for API key
- this.ai = new GoogleGenAI({
- apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
- });
- this.modelName = this.config.getModel();
}
validateParams(params: WebSearchToolParams): string | null {
@@ -112,7 +101,10 @@ export class WebSearchTool extends BaseTool<
return `Searching the web for: "${params.query}"`;
}
- async execute(params: WebSearchToolParams): Promise<WebSearchToolResult> {
+ async execute(
+ params: WebSearchToolParams,
+ signal: AbortSignal,
+ ): Promise<WebSearchToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
@@ -120,18 +112,14 @@ export class WebSearchTool extends BaseTool<
returnDisplay: validationError,
};
}
+ const geminiClient = this.config.getGeminiClient();
try {
- const apiCall = () =>
- this.ai.models.generateContent({
- model: this.modelName,
- contents: [{ role: 'user', parts: [{ text: params.query }] }],
- config: {
- tools: [{ googleSearch: {} }],
- },
- });
-
- const response = await retryWithBackoff(apiCall);
+ const response = await geminiClient.generateContent(
+ [{ role: 'user', parts: [{ text: params.query }] }],
+ { tools: [{ googleSearch: {} }] },
+ signal,
+ );
const responseText = getResponseText(response);
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;