diff options
Diffstat (limited to 'packages/core/src/code_assist')
| -rw-r--r-- | packages/core/src/code_assist/converter.test.ts | 57 | ||||
| -rw-r--r-- | packages/core/src/code_assist/converter.ts | 3 | ||||
| -rw-r--r-- | packages/core/src/code_assist/server.test.ts | 78 | ||||
| -rw-r--r-- | packages/core/src/code_assist/server.ts | 16 | ||||
| -rw-r--r-- | packages/core/src/code_assist/setup.test.ts | 10 | ||||
| -rw-r--r-- | packages/core/src/code_assist/setup.ts | 2 |
6 files changed, 137 insertions, 29 deletions
diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts index 03f388dc..3d3a8ef3 100644 --- a/packages/core/src/code_assist/converter.test.ts +++ b/packages/core/src/code_assist/converter.test.ts @@ -24,7 +24,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project'); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: 'my-project', @@ -37,8 +42,9 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: undefined, + session_id: 'my-session', }, + user_prompt_id: 'my-prompt', }); }); @@ -47,7 +53,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + undefined, + 'my-session', + ); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: undefined, @@ -60,8 +71,9 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: undefined, + session_id: 'my-session', }, + user_prompt_id: 'my-prompt', }); }); @@ -72,6 +84,7 @@ describe('converter', () => { }; const codeAssistReq = toGenerateContentRequest( genaiReq, + 'my-prompt', 'my-project', 'session-123', ); @@ -89,6 +102,7 @@ describe('converter', () => { generationConfig: undefined, session_id: 'session-123', }, + user_prompt_id: 'my-prompt', }); }); @@ -97,7 +111,12 @@ describe('converter', () => { model: 'gemini-pro', contents: 'Hello', }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, ]); @@ -108,7 +127,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ text: 'Hello' }, { text: 'World' }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'World' }] }, @@ -123,7 +147,12 @@ describe('converter', () => { systemInstruction: 'You are a helpful assistant.', }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.systemInstruction).toEqual({ role: 'user', parts: [{ text: 'You are a helpful assistant.' }], @@ -139,7 +168,12 @@ describe('converter', () => { topK: 40, }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.8, topK: 40, @@ -165,7 +199,12 @@ describe('converter', () => { responseMimeType: 'application/json', }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.1, topP: 0.2, diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index 8340cfc1..ffd471da 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -32,6 +32,7 @@ import { export interface CAGenerateContentRequest { model: string; project?: string; + user_prompt_id?: string; request: VertexGenerateContentRequest; } @@ -115,12 +116,14 @@ export function fromCountTokenResponse( export function toGenerateContentRequest( req: GenerateContentParameters, + userPromptId: string, project?: string, sessionId?: string, ): CAGenerateContentRequest { return { model: req.model, project, + user_prompt_id: userPromptId, request: toVertexGenerateContentRequest(req, sessionId), }; } diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 6246fd4e..3fc1891f 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -14,13 +14,25 @@ vi.mock('google-auth-library'); describe('CodeAssistServer', () => { it('should be able to be constructed', () => { const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const server = new CodeAssistServer( + auth, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); expect(server).toBeInstanceOf(CodeAssistServer); }); it('should call the generateContent endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { response: { candidates: [ @@ -38,10 +50,13 @@ describe('CodeAssistServer', () => { }; vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse); - const response = await server.generateContent({ - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }); + const response = await server.generateContent( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); expect(server.requestPost).toHaveBeenCalledWith( 'generateContent', @@ -55,7 +70,13 @@ describe('CodeAssistServer', () => { it('should call the generateContentStream endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = (async function* () { yield { response: { @@ -75,10 +96,13 @@ describe('CodeAssistServer', () => { })(); vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse); - const stream = await server.generateContentStream({ - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }); + const stream = await server.generateContentStream( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); for await (const res of stream) { expect(server.requestStreamingPost).toHaveBeenCalledWith( @@ -92,7 +116,13 @@ describe('CodeAssistServer', () => { it('should call the onboardUser endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { name: 'operations/123', done: true, @@ -114,7 +144,13 @@ describe('CodeAssistServer', () => { it('should call the loadCodeAssist endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { currentTier: { id: UserTierId.FREE, @@ -140,7 +176,13 @@ describe('CodeAssistServer', () => { it('should return 0 for countTokens', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { totalTokens: 100, }; @@ -155,7 +197,13 @@ describe('CodeAssistServer', () => { it('should throw an error for embedContent', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); await expect( server.embedContent({ model: 'test-model', diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 7af643f7..08339bdc 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -53,10 +53,16 @@ export class CodeAssistServer implements ContentGenerator { async generateContentStream( req: GenerateContentParameters, + userPromptId: string, ): Promise<AsyncGenerator<GenerateContentResponse>> { const resps = await this.requestStreamingPost<CaGenerateContentResponse>( 'streamGenerateContent', - toGenerateContentRequest(req, this.projectId, this.sessionId), + toGenerateContentRequest( + req, + userPromptId, + this.projectId, + this.sessionId, + ), req.config?.abortSignal, ); return (async function* (): AsyncGenerator<GenerateContentResponse> { @@ -68,10 +74,16 @@ export class CodeAssistServer implements ContentGenerator { async generateContent( req: GenerateContentParameters, + userPromptId: string, ): Promise<GenerateContentResponse> { const resp = await this.requestPost<CaGenerateContentResponse>( 'generateContent', - toGenerateContentRequest(req, this.projectId, this.sessionId), + toGenerateContentRequest( + req, + userPromptId, + this.projectId, + this.sessionId, + ), req.config?.abortSignal, ); return fromGenerateContentResponse(resp); diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index 6db5fd88..c1260e3f 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -49,8 +49,11 @@ describe('setupUser', () => { }); await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - expect.any(Object), + {}, 'test-project', + {}, + '', + undefined, ); }); @@ -62,7 +65,10 @@ describe('setupUser', () => { }); const projectId = await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - expect.any(Object), + {}, + undefined, + {}, + '', undefined, ); expect(projectId).toEqual({ diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 8831d24b..9c7a8043 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -34,7 +34,7 @@ export interface UserData { */ export async function setupUser(client: OAuth2Client): Promise<UserData> { let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined; - const caServer = new CodeAssistServer(client, projectId); + const caServer = new CodeAssistServer(client, projectId, {}, '', undefined); const clientMetadata: ClientMetadata = { ideType: 'IDE_UNSPECIFIED', |
