summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorAbhi <[email protected]>2025-06-22 01:35:36 -0400
committerGitHub <[email protected]>2025-06-22 01:35:36 -0400
commitc9950b3cb273246d801a5cbb04cf421d4c5e39c4 (patch)
tree0acd0de4ef11c6031c70489bba6063bbba4ca8f1 /packages/core/src
parent5cf8dc4f0784408f4c2fcfc56d6e834facccf4a3 (diff)
feat: Add client-initiated tool call handling (#1292)
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/core/coreToolScheduler.test.ts7
-rw-r--r--packages/core/src/core/nonInteractiveToolExecutor.test.ts5
-rw-r--r--packages/core/src/core/turn.test.ts25
-rw-r--r--packages/core/src/core/turn.ts19
4 files changed, 43 insertions, 13 deletions
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts
index 8e09a2af..63feb874 100644
--- a/packages/core/src/core/coreToolScheduler.test.ts
+++ b/packages/core/src/core/coreToolScheduler.test.ts
@@ -88,7 +88,12 @@ describe('CoreToolScheduler', () => {
});
const abortController = new AbortController();
- const request = { callId: '1', name: 'mockTool', args: {} };
+ const request = {
+ callId: '1',
+ name: 'mockTool',
+ args: {},
+ isClientInitiated: false,
+ };
abortController.abort();
await scheduler.schedule([request], abortController.signal);
diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts
index 07946af5..edf11d35 100644
--- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts
+++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts
@@ -62,6 +62,7 @@ describe('executeToolCall', () => {
callId: 'call1',
name: 'testTool',
args: { param1: 'value1' },
+ isClientInitiated: false,
};
const toolResult: ToolResult = {
llmContent: 'Tool executed successfully',
@@ -99,6 +100,7 @@ describe('executeToolCall', () => {
callId: 'call2',
name: 'nonExistentTool',
args: {},
+ isClientInitiated: false,
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined);
@@ -133,6 +135,7 @@ describe('executeToolCall', () => {
callId: 'call3',
name: 'testTool',
args: { param1: 'value1' },
+ isClientInitiated: false,
};
const executionError = new Error('Tool execution failed');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@@ -164,6 +167,7 @@ describe('executeToolCall', () => {
callId: 'call4',
name: 'testTool',
args: { param1: 'value1' },
+ isClientInitiated: false,
};
const cancellationError = new Error('Operation cancelled');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@@ -206,6 +210,7 @@ describe('executeToolCall', () => {
callId: 'call5',
name: 'testTool',
args: {},
+ isClientInitiated: false,
};
const imageDataPart: Part = {
inlineData: { mimeType: 'image/png', data: 'base64data' },
diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts
index aeb30229..a525cbff 100644
--- a/packages/core/src/core/turn.test.ts
+++ b/packages/core/src/core/turn.test.ts
@@ -132,8 +132,13 @@ describe('Turn', () => {
const mockResponseStream = (async function* () {
yield {
functionCalls: [
- { id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
- { name: 'tool2', args: { arg2: 'val2' } }, // No ID
+ {
+ id: 'fc1',
+ name: 'tool1',
+ args: { arg1: 'val1' },
+ isClientInitiated: false,
+ },
+ { name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID
],
} as unknown as GenerateContentResponse;
})();
@@ -156,6 +161,7 @@ describe('Turn', () => {
callId: 'fc1',
name: 'tool1',
args: { arg1: 'val1' },
+ isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@@ -163,7 +169,11 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
- expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
+ expect.objectContaining({
+ name: 'tool2',
+ args: { arg2: 'val2' },
+ isClientInitiated: false,
+ }),
);
expect(event2.value.callId).toEqual(
expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
@@ -301,6 +311,7 @@ describe('Turn', () => {
callId: 'fc1',
name: 'undefined_tool_name',
args: { arg1: 'val1' },
+ isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@@ -308,7 +319,12 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
- expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
+ expect.objectContaining({
+ callId: 'fc2',
+ name: 'tool2',
+ args: {},
+ isClientInitiated: false,
+ }),
);
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
@@ -319,6 +335,7 @@ describe('Turn', () => {
callId: 'fc3',
name: 'undefined_tool_name',
args: {},
+ isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[2]).toEqual(event3.value);
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 4cc4bf4d..cdb4a89f 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -57,6 +57,7 @@ export interface ToolCallRequestInfo {
callId: string;
name: string;
args: Record<string, unknown>;
+ isClientInitiated: boolean;
}
export interface ToolCallResponseInfo {
@@ -139,11 +140,7 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
- readonly pendingToolCalls: Array<{
- callId: string;
- name: string;
- args: Record<string, unknown>;
- }>;
+ readonly pendingToolCalls: ToolCallRequestInfo[];
private debugResponses: GenerateContentResponse[];
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
@@ -254,11 +251,17 @@ export class Turn {
const name = fnCall.name || 'undefined_tool_name';
const args = (fnCall.args || {}) as Record<string, unknown>;
- this.pendingToolCalls.push({ callId, name, args });
+ const toolCallRequest: ToolCallRequestInfo = {
+ callId,
+ name,
+ args,
+ isClientInitiated: false,
+ };
+
+ this.pendingToolCalls.push(toolCallRequest);
// Yield a request for the tool call, not the pending/confirming status
- const value: ToolCallRequestInfo = { callId, name, args };
- return { type: GeminiEventType.ToolCallRequest, value };
+ return { type: GeminiEventType.ToolCallRequest, value: toolCallRequest };
}
getDebugResponses(): GenerateContentResponse[] {