summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorBryan Morgan <[email protected]>2025-07-01 19:16:09 -0400
committerGitHub <[email protected]>2025-07-01 23:16:09 +0000
commitdbe88f6e0e8efb989b21fc8b46e0da124f5204ff (patch)
treebfb9e5e2f15acd925e24a182086362c68397bde9 /packages/core/src
parent3492c429b95b7e905cd7cc7538e95b38809cc53e (diff)
Added support for session_id in API calls (#2886)
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/code_assist/codeAssist.ts3
-rw-r--r--packages/core/src/code_assist/converter.test.ts29
-rw-r--r--packages/core/src/code_assist/converter.ts6
-rw-r--r--packages/core/src/code_assist/server.ts5
-rw-r--r--packages/core/src/core/client.ts1
-rw-r--r--packages/core/src/core/contentGenerator.ts7
6 files changed, 46 insertions, 5 deletions
diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts
index c3cb9293..80d95ca9 100644
--- a/packages/core/src/code_assist/codeAssist.ts
+++ b/packages/core/src/code_assist/codeAssist.ts
@@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
authType: AuthType,
+ sessionId?: string,
): Promise<ContentGenerator> {
if (authType === AuthType.LOGIN_WITH_GOOGLE) {
const authClient = await getOauthClient();
const projectId = await setupUser(authClient);
- return new CodeAssistServer(authClient, projectId, httpOptions);
+ return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
}
throw new Error(`Unsupported authType: ${authType}`);
diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts
index 2170c960..03f388dc 100644
--- a/packages/core/src/code_assist/converter.test.ts
+++ b/packages/core/src/code_assist/converter.test.ts
@@ -37,6 +37,7 @@ describe('converter', () => {
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
+ session_id: undefined,
},
});
});
@@ -59,6 +60,34 @@ describe('converter', () => {
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
+ session_id: undefined,
+ },
+ });
+ });
+
+ it('should convert a request with sessionId', () => {
+ const genaiReq: GenerateContentParameters = {
+ model: 'gemini-pro',
+ contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
+ };
+ const codeAssistReq = toGenerateContentRequest(
+ genaiReq,
+ 'my-project',
+ 'session-123',
+ );
+ expect(codeAssistReq).toEqual({
+ model: 'gemini-pro',
+ project: 'my-project',
+ request: {
+ contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
+ systemInstruction: undefined,
+ cachedContent: undefined,
+ tools: undefined,
+ toolConfig: undefined,
+ labels: undefined,
+ safetySettings: undefined,
+ generationConfig: undefined,
+ session_id: 'session-123',
},
});
});
diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts
index b9b854fc..b27617c4 100644
--- a/packages/core/src/code_assist/converter.ts
+++ b/packages/core/src/code_assist/converter.ts
@@ -44,6 +44,7 @@ interface VertexGenerateContentRequest {
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig;
+ session_id?: string;
}
interface VertexGenerationConfig {
@@ -114,11 +115,12 @@ export function fromCountTokenResponse(
export function toGenerateContentRequest(
req: GenerateContentParameters,
project?: string,
+ sessionId?: string,
): CAGenerateContentRequest {
return {
model: req.model,
project,
- request: toVertexGenerateContentRequest(req),
+ request: toVertexGenerateContentRequest(req, sessionId),
};
}
@@ -136,6 +138,7 @@ export function fromGenerateContentResponse(
function toVertexGenerateContentRequest(
req: GenerateContentParameters,
+ sessionId?: string,
): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
@@ -146,6 +149,7 @@ function toVertexGenerateContentRequest(
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
generationConfig: toVertexGenerationConfig(req.config),
+ session_id: sessionId,
};
}
diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts
index 3cf0c721..f285dba8 100644
--- a/packages/core/src/code_assist/server.ts
+++ b/packages/core/src/code_assist/server.ts
@@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator {
readonly client: OAuth2Client,
readonly projectId?: string,
readonly httpOptions: HttpOptions = {},
+ readonly sessionId?: string,
) {}
async generateContentStream(
@@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
- toGenerateContentRequest(req, this.projectId),
+ toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
@@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
- toGenerateContentRequest(req, this.projectId),
+ toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index b00a689b..fe60112d 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -68,6 +68,7 @@ export class GeminiClient {
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
this.contentGenerator = await createContentGenerator(
contentGeneratorConfig,
+ this.config.getSessionId(),
);
this.chat = await this.startChat();
}
diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts
index 4740c4ee..f0c163d2 100644
--- a/packages/core/src/core/contentGenerator.ts
+++ b/packages/core/src/core/contentGenerator.ts
@@ -101,6 +101,7 @@ export async function createContentGeneratorConfig(
export async function createContentGenerator(
config: ContentGeneratorConfig,
+ sessionId?: string,
): Promise<ContentGenerator> {
const version = process.env.CLI_VERSION || process.version;
const httpOptions = {
@@ -109,7 +110,11 @@ export async function createContentGenerator(
},
};
if (config.authType === AuthType.LOGIN_WITH_GOOGLE) {
- return createCodeAssistContentGenerator(httpOptions, config.authType);
+ return createCodeAssistContentGenerator(
+ httpOptions,
+ config.authType,
+ sessionId,
+ );
}
if (