summaryrefslogtreecommitdiff
path: root/packages/server/src/core
diff options
context:
space:
mode:
Diffstat (limited to 'packages/server/src/core')
-rw-r--r--packages/server/src/core/client.ts13
-rw-r--r--packages/server/src/core/geminiChat.ts4
-rw-r--r--packages/server/src/core/turn.test.ts27
-rw-r--r--packages/server/src/core/turn.ts12
4 files changed, 44 insertions, 12 deletions
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts
index 341ce021..69b815ab 100644
--- a/packages/server/src/core/client.ts
+++ b/packages/server/src/core/client.ts
@@ -157,7 +157,7 @@ export class GeminiClient {
async *sendMessageStream(
chat: GeminiChat,
request: PartListUnion,
- signal?: AbortSignal,
+ signal: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> {
if (!turns) {
@@ -169,8 +169,8 @@ export class GeminiClient {
for await (const event of resultStream) {
yield event;
}
- if (!turn.pendingToolCalls.length) {
- const nextSpeakerCheck = await checkNextSpeaker(chat, this);
+ if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
+ const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
@@ -181,6 +181,7 @@ export class GeminiClient {
async generateJson(
contents: Content[],
schema: SchemaUnion,
+ abortSignal: AbortSignal,
model: string = 'gemini-2.0-flash',
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
@@ -188,6 +189,7 @@ export class GeminiClient {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
+ abortSignal,
...this.generateContentConfig,
...config,
};
@@ -232,6 +234,11 @@ export class GeminiClient {
);
}
} catch (error) {
+ if (abortSignal.aborted) {
+ // Regular cancellation error, fail normally
+ throw error;
+ }
+
// Avoid double reporting for the empty response case handled above
if (
error instanceof Error &&
diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts
index c971e2cc..5ba8ce2d 100644
--- a/packages/server/src/core/geminiChat.ts
+++ b/packages/server/src/core/geminiChat.ts
@@ -155,7 +155,7 @@ export class GeminiChat {
const responsePromise = this.modelsModule.generateContent({
model: this.model,
contents: this.getHistory(true).concat(userContent),
- config: params.config ?? this.config,
+ config: { ...this.config, ...params.config },
});
this.sendPromise = (async () => {
const response = await responsePromise;
@@ -219,7 +219,7 @@ export class GeminiChat {
const streamResponse = this.modelsModule.generateContentStream({
model: this.model,
contents: this.getHistory(true).concat(userContent),
- config: params.config ?? this.config,
+ config: { ...this.config, ...params.config },
});
// Resolve the internal tracking of send completion promise - `sendPromise`
// for both success and failure response. The actual failure is still
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
index 44bb983f..8fb3a4c1 100644
--- a/packages/server/src/core/turn.test.ts
+++ b/packages/server/src/core/turn.test.ts
@@ -85,11 +85,17 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
- expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
+ expect(mockSendMessageStream).toHaveBeenCalledWith({
+ message: reqParts,
+ config: { abortSignal: expect.any(AbortSignal) },
+ });
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
@@ -110,7 +116,10 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Use tools' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -179,7 +188,10 @@ describe('Turn', () => {
mockGetHistory.mockReturnValue(historyContent);
const events = [];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -210,7 +222,10 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
- for await (const event of turn.run(reqParts)) {
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
events.push(event);
}
@@ -261,7 +276,7 @@ describe('Turn', () => {
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const _ of turn.run(reqParts)) {
+ for await (const _ of turn.run(reqParts, new AbortController().signal)) {
// consume stream
}
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index d5c7eb58..97e93f59 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -32,6 +32,7 @@ export interface ServerTool {
): Promise<ToolResult>;
shouldConfirmExecute(
params: Record<string, unknown>,
+ abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>;
}
@@ -120,11 +121,14 @@ export class Turn {
// The run method yields simpler events suitable for server logic
async *run(
req: PartListUnion,
- signal?: AbortSignal,
+ signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> {
try {
const responseStream = await this.chat.sendMessageStream({
message: req,
+ config: {
+ abortSignal: signal,
+ },
});
for await (const resp of responseStream) {
@@ -150,6 +154,12 @@ export class Turn {
}
}
} catch (error) {
+ if (signal.aborted) {
+ yield { type: GeminiEventType.UserCancelled };
+ // Regular cancellation error, fail gracefully.
+ return;
+ }
+
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError(
error,