summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/code_assist/converter.test.ts32
-rw-r--r--packages/core/src/code_assist/converter.ts69
-rw-r--r--packages/core/src/code_assist/server.test.ts7
-rw-r--r--packages/core/src/code_assist/server.ts29
4 files changed, 91 insertions, 46 deletions
diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts
index d0c05015..2170c960 100644
--- a/packages/core/src/code_assist/converter.test.ts
+++ b/packages/core/src/code_assist/converter.test.ts
@@ -6,9 +6,9 @@
import { describe, it, expect } from 'vitest';
import {
- toCodeAssistRequest,
- fromCodeAsistResponse,
- CodeAssistResponse,
+ toGenerateContentRequest,
+ fromGenerateContentResponse,
+ CaGenerateContentResponse,
} from './converter.js';
import {
GenerateContentParameters,
@@ -24,7 +24,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
- const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project');
+ const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
@@ -46,7 +46,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: undefined,
@@ -68,7 +68,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: 'Hello',
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
]);
@@ -79,7 +79,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ text: 'Hello' }, { text: 'World' }],
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'user', parts: [{ text: 'World' }] },
@@ -94,7 +94,7 @@ describe('converter', () => {
systemInstruction: 'You are a helpful assistant.',
},
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.systemInstruction).toEqual({
role: 'user',
parts: [{ text: 'You are a helpful assistant.' }],
@@ -110,7 +110,7 @@ describe('converter', () => {
topK: 40,
},
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.8,
topK: 40,
@@ -136,7 +136,7 @@ describe('converter', () => {
responseMimeType: 'application/json',
},
};
- const codeAssistReq = toCodeAssistRequest(genaiReq);
+ const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.1,
topP: 0.2,
@@ -156,7 +156,7 @@ describe('converter', () => {
describe('fromCodeAssistResponse', () => {
it('should convert a simple response', () => {
- const codeAssistRes: CodeAssistResponse = {
+ const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [
{
@@ -171,13 +171,13 @@ describe('converter', () => {
],
},
};
- const genaiRes = fromCodeAsistResponse(codeAssistRes);
+ const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
});
it('should handle prompt feedback and usage metadata', () => {
- const codeAssistRes: CodeAssistResponse = {
+ const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
promptFeedback: {
@@ -191,7 +191,7 @@ describe('converter', () => {
},
},
};
- const genaiRes = fromCodeAsistResponse(codeAssistRes);
+ const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.promptFeedback).toEqual(
codeAssistRes.response.promptFeedback,
);
@@ -201,7 +201,7 @@ describe('converter', () => {
});
it('should handle automatic function calling history', () => {
- const codeAssistRes: CodeAssistResponse = {
+ const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
automaticFunctionCallingHistory: [
@@ -221,7 +221,7 @@ describe('converter', () => {
],
},
};
- const genaiRes = fromCodeAsistResponse(codeAssistRes);
+ const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
codeAssistRes.response.automaticFunctionCallingHistory,
);
diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts
index 495cbfae..b9b854fc 100644
--- a/packages/core/src/code_assist/converter.ts
+++ b/packages/core/src/code_assist/converter.ts
@@ -10,6 +10,8 @@ import {
ContentUnion,
GenerateContentConfig,
GenerateContentParameters,
+ CountTokensParameters,
+ CountTokensResponse,
GenerateContentResponse,
GenerationConfigRoutingConfig,
MediaResolution,
@@ -27,13 +29,13 @@ import {
ToolConfig,
} from '@google/genai';
-export interface CodeAssistRequest {
+export interface CAGenerateContentRequest {
model: string;
project?: string;
- request: CodeAssistGenerateContentRequest;
+ request: VertexGenerateContentRequest;
}
-interface CodeAssistGenerateContentRequest {
+interface VertexGenerateContentRequest {
contents: Content[];
systemInstruction?: Content;
cachedContent?: string;
@@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest {
toolConfig?: ToolConfig;
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
- generationConfig?: CodeAssistGenerationConfig;
+ generationConfig?: VertexGenerationConfig;
}
-interface CodeAssistGenerationConfig {
+interface VertexGenerationConfig {
temperature?: number;
topP?: number;
topK?: number;
@@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig {
thinkingConfig?: ThinkingConfig;
}
-export interface CodeAssistResponse {
- response: VertexResponse;
+export interface CaGenerateContentResponse {
+ response: VertexGenerateContentResponse;
}
-interface VertexResponse {
+interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
}
+export interface CaCountTokenRequest {
+ request: VertexCountTokenRequest;
+}
+
+interface VertexCountTokenRequest {
+ model: string;
+ contents: Content[];
+}
+
+export interface CaCountTokenResponse {
+ totalTokens: number;
+}
+
+export function toCountTokenRequest(
+ req: CountTokensParameters,
+): CaCountTokenRequest {
+ return {
+ request: {
+ model: 'models/' + req.model,
+ contents: toContents(req.contents),
+ },
+ };
+}
+
+export function fromCountTokenResponse(
+ res: CaCountTokenResponse,
+): CountTokensResponse {
+ return {
+ totalTokens: res.totalTokens,
+ };
+}
-export function toCodeAssistRequest(
+export function toGenerateContentRequest(
req: GenerateContentParameters,
project?: string,
-): CodeAssistRequest {
+): CAGenerateContentRequest {
return {
model: req.model,
project,
- request: toCodeAssistGenerateContentRequest(req),
+ request: toVertexGenerateContentRequest(req),
};
}
-export function fromCodeAsistResponse(
- res: CodeAssistResponse,
+export function fromGenerateContentResponse(
+ res: CaGenerateContentResponse,
): GenerateContentResponse {
const inres = res.response;
const out = new GenerateContentResponse();
@@ -101,9 +134,9 @@ export function fromCodeAsistResponse(
return out;
}
-function toCodeAssistGenerateContentRequest(
+function toVertexGenerateContentRequest(
req: GenerateContentParameters,
-): CodeAssistGenerateContentRequest {
+): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
systemInstruction: maybeToContent(req.config?.systemInstruction),
@@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest(
toolConfig: req.config?.toolConfig,
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
- generationConfig: toCodeAssistGenerationConfig(req.config),
+ generationConfig: toVertexGenerationConfig(req.config),
};
}
@@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part {
return part;
}
-function toCodeAssistGenerationConfig(
+function toVertexGenerationConfig(
config?: GenerateContentConfig,
-): CodeAssistGenerationConfig | undefined {
+): VertexGenerationConfig | undefined {
if (!config) {
return undefined;
}
diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts
index 922d20fb..d8d9c10a 100644
--- a/packages/core/src/code_assist/server.test.ts
+++ b/packages/core/src/code_assist/server.test.ts
@@ -133,11 +133,16 @@ describe('CodeAssistServer', () => {
it('should return 0 for countTokens', async () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project');
+ const mockResponse = {
+ totalTokens: 100,
+ };
+ vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
+
const response = await server.countTokens({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
- expect(response.totalTokens).toBe(0);
+ expect(response.totalTokens).toBe(100);
});
it('should throw an error for embedContent', async () => {
diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts
index d700353c..4f8bb643 100644
--- a/packages/core/src/code_assist/server.ts
+++ b/packages/core/src/code_assist/server.ts
@@ -22,9 +22,12 @@ import {
import * as readline from 'readline';
import { ContentGenerator } from '../core/contentGenerator.js';
import {
- CodeAssistResponse,
- toCodeAssistRequest,
- fromCodeAsistResponse,
+ CaGenerateContentResponse,
+ toGenerateContentRequest,
+ fromGenerateContentResponse,
+ toCountTokenRequest,
+ fromCountTokenResponse,
+ CaCountTokenResponse,
} from './converter.js';
import { PassThrough } from 'node:stream';
@@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator {
async generateContentStream(
req: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> {
- const resps = await this.streamEndpoint<CodeAssistResponse>(
+ const resps = await this.streamEndpoint<CaGenerateContentResponse>(
'streamGenerateContent',
- toCodeAssistRequest(req, this.projectId),
+ toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
- yield fromCodeAsistResponse(resp);
+ yield fromGenerateContentResponse(resp);
}
})();
}
@@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator {
async generateContent(
req: GenerateContentParameters,
): Promise<GenerateContentResponse> {
- const resp = await this.callEndpoint<CodeAssistResponse>(
+ const resp = await this.callEndpoint<CaGenerateContentResponse>(
'generateContent',
- toCodeAssistRequest(req, this.projectId),
+ toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
);
- return fromCodeAsistResponse(resp);
+ return fromGenerateContentResponse(resp);
}
async onboardUser(
@@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator {
);
}
- async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> {
- return { totalTokens: 0 };
+ async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
+ const resp = await this.callEndpoint<CaCountTokenResponse>(
+ 'countTokens',
+ toCountTokenRequest(req),
+ );
+ return fromCountTokenResponse(resp);
}
async embedContent(