summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/core/src/code_assist/server.test.ts18
-rw-r--r--packages/core/src/code_assist/server.ts34
2 files changed, 28 insertions, 24 deletions
diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts
index 9bcfa304..6944af47 100644
--- a/packages/core/src/code_assist/server.test.ts
+++ b/packages/core/src/code_assist/server.test.ts
@@ -35,14 +35,14 @@ describe('CodeAssistServer', () => {
],
},
};
- vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
+ vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.generateContent({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
- expect(server.callEndpoint).toHaveBeenCalledWith(
+ expect(server.requestPost).toHaveBeenCalledWith(
'generateContent',
expect.any(Object),
undefined,
@@ -72,7 +72,7 @@ describe('CodeAssistServer', () => {
},
};
})();
- vi.spyOn(server, 'streamEndpoint').mockResolvedValue(mockResponse);
+ vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
const stream = await server.generateContentStream({
model: 'test-model',
@@ -80,7 +80,7 @@ describe('CodeAssistServer', () => {
});
for await (const res of stream) {
- expect(server.streamEndpoint).toHaveBeenCalledWith(
+ expect(server.requestStreamingPost).toHaveBeenCalledWith(
'streamGenerateContent',
expect.any(Object),
undefined,
@@ -96,7 +96,7 @@ describe('CodeAssistServer', () => {
name: 'operations/123',
done: true,
};
- vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
+ vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.onboardUser({
tierId: 'test-tier',
@@ -104,7 +104,7 @@ describe('CodeAssistServer', () => {
metadata: {},
});
- expect(server.callEndpoint).toHaveBeenCalledWith(
+ expect(server.requestPost).toHaveBeenCalledWith(
'onboardUser',
expect.any(Object),
);
@@ -117,13 +117,13 @@ describe('CodeAssistServer', () => {
const mockResponse = {
// TODO: Add mock response
};
- vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
+ vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.loadCodeAssist({
metadata: {},
});
- expect(server.callEndpoint).toHaveBeenCalledWith(
+ expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
@@ -136,7 +136,7 @@ describe('CodeAssistServer', () => {
const mockResponse = {
totalTokens: 100,
};
- vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
+ vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.countTokens({
model: 'test-model',
diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts
index 8e74c8b2..3cf0c721 100644
--- a/packages/core/src/code_assist/server.ts
+++ b/packages/core/src/code_assist/server.ts
@@ -40,8 +40,7 @@ export interface HttpOptions {
}
// TODO: Use production endpoint once it supports our methods.
-export const CODE_ASSIST_ENDPOINT =
- process.env.CODE_ASSIST_ENDPOINT ?? 'https://cloudcode-pa.googleapis.com';
+export const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com';
export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator {
@@ -54,7 +53,7 @@ export class CodeAssistServer implements ContentGenerator {
async generateContentStream(
req: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> {
- const resps = await this.streamEndpoint<CaGenerateContentResponse>(
+ const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
@@ -69,7 +68,7 @@ export class CodeAssistServer implements ContentGenerator {
async generateContent(
req: GenerateContentParameters,
): Promise<GenerateContentResponse> {
- const resp = await this.callEndpoint<CaGenerateContentResponse>(
+ const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
@@ -80,7 +79,7 @@ export class CodeAssistServer implements ContentGenerator {
async onboardUser(
req: OnboardUserRequest,
): Promise<LongrunningOperationResponse> {
- return await this.callEndpoint<LongrunningOperationResponse>(
+ return await this.requestPost<LongrunningOperationResponse>(
'onboardUser',
req,
);
@@ -89,14 +88,14 @@ export class CodeAssistServer implements ContentGenerator {
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
- return await this.callEndpoint<LoadCodeAssistResponse>(
+ return await this.requestPost<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
);
}
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
- return await this.getEndpoint<CodeAssistGlobalUserSettingResponse>(
+ return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
'getCodeAssistGlobalUserSetting',
);
}
@@ -104,14 +103,14 @@ export class CodeAssistServer implements ContentGenerator {
async setCodeAssistGlobalUserSetting(
req: SetCodeAssistGlobalUserSettingRequest,
): Promise<CodeAssistGlobalUserSettingResponse> {
- return await this.callEndpoint<CodeAssistGlobalUserSettingResponse>(
+ return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
'setCodeAssistGlobalUserSetting',
req,
);
}
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
- const resp = await this.callEndpoint<CaCountTokenResponse>(
+ const resp = await this.requestPost<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
@@ -124,13 +123,13 @@ export class CodeAssistServer implements ContentGenerator {
throw Error();
}
- async callEndpoint<T>(
+ async requestPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
- url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
+ url: this.getMethodUrl(method),
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -143,9 +142,9 @@ export class CodeAssistServer implements ContentGenerator {
return res.data as T;
}
- async getEndpoint<T>(method: string, signal?: AbortSignal): Promise<T> {
+ async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
const res = await this.client.request({
- url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
+ url: this.getMethodUrl(method),
method: 'GET',
headers: {
'Content-Type': 'application/json',
@@ -157,13 +156,13 @@ export class CodeAssistServer implements ContentGenerator {
return res.data as T;
}
- async streamEndpoint<T>(
+ async requestStreamingPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<AsyncGenerator<T>> {
const res = await this.client.request({
- url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
+ url: this.getMethodUrl(method),
method: 'POST',
params: {
alt: 'sse',
@@ -200,4 +199,9 @@ export class CodeAssistServer implements ContentGenerator {
}
})();
}
+
+ getMethodUrl(method: string): string {
+ const endpoint = process.env.CODE_ASSIST_ENDPOINT ?? CODE_ASSIST_ENDPOINT;
+ return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
+ }
}