summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
Diffstat (limited to 'packages')
-rw-r--r--packages/cli/src/ui/hooks/atCommandProcessor.ts4
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts115
-rw-r--r--packages/server/src/core/client.ts11
-rw-r--r--packages/server/src/core/turn.ts13
-rw-r--r--packages/server/src/tools/edit.ts5
-rw-r--r--packages/server/src/tools/glob.ts5
-rw-r--r--packages/server/src/tools/grep.ts5
-rw-r--r--packages/server/src/tools/ls.ts5
-rw-r--r--packages/server/src/tools/read-file.ts5
-rw-r--r--packages/server/src/tools/read-many-files.ts5
-rw-r--r--packages/server/src/tools/shell.ts70
-rw-r--r--packages/server/src/tools/terminal.ts5
-rw-r--r--packages/server/src/tools/tools.ts5
-rw-r--r--packages/server/src/tools/web-fetch.ts5
-rw-r--r--packages/server/src/tools/write-file.ts5
15 files changed, 191 insertions, 72 deletions
diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts
index 5ffa5383..a13a7d36 100644
--- a/packages/cli/src/ui/hooks/atCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts
@@ -26,6 +26,7 @@ interface HandleAtCommandParams {
addItem: UseHistoryManagerReturn['addItem'];
setDebugMessage: React.Dispatch<React.SetStateAction<string>>;
messageId: number;
+ signal: AbortSignal;
}
interface HandleAtCommandResult {
@@ -90,6 +91,7 @@ export async function handleAtCommand({
addItem,
setDebugMessage,
messageId: userMessageTimestamp,
+ signal,
}: HandleAtCommandParams): Promise<HandleAtCommandResult> {
const trimmedQuery = query.trim();
const parsedCommand = parseAtCommand(trimmedQuery);
@@ -163,7 +165,7 @@ export async function handleAtCommand({
let toolCallDisplay: IndividualToolCallDisplay;
try {
- const result = await readManyFilesTool.execute(toolArgs);
+ const result = await readManyFilesTool.execute(toolArgs, signal);
const fileContent = result.llmContent || '';
toolCallDisplay = {
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 3f8cee40..e86ae0b9 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -89,7 +89,7 @@ export const useGeminiStream = (
}, [config, addItem]);
useInput((_input, key) => {
- if (streamingState === StreamingState.Responding && key.escape) {
+ if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
}
});
@@ -104,6 +104,9 @@ export const useGeminiStream = (
setShowHelp(false);
+ abortControllerRef.current ??= new AbortController();
+ const signal = abortControllerRef.current.signal;
+
if (typeof query === 'string') {
const trimmedQuery = query.trim();
setDebugMessage(`User query: '${trimmedQuery}'`);
@@ -120,6 +123,7 @@ export const useGeminiStream = (
addItem,
setDebugMessage,
messageId: userMessageTimestamp,
+ signal,
});
if (!atCommandResult.shouldProceed) return;
queryToSendToGemini = atCommandResult.processedQuery;
@@ -165,9 +169,6 @@ export const useGeminiStream = (
const chat = chatSessionRef.current;
try {
- abortControllerRef.current = new AbortController();
- const signal = abortControllerRef.current.signal;
-
const stream = client.sendMessageStream(
chat,
queryToSendToGemini,
@@ -294,7 +295,26 @@ export const useGeminiStream = (
} else if (event.type === ServerGeminiEventType.UserCancelled) {
// Flush out existing pending history item.
if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+ // If the pending item is a tool_group, update statuses to Canceled
+ if (pendingHistoryItemRef.current.type === 'tool_group') {
+ const updatedTools = pendingHistoryItemRef.current.tools.map(
+ (tool) => {
+ if (
+ tool.status === ToolCallStatus.Pending ||
+ tool.status === ToolCallStatus.Confirming ||
+ tool.status === ToolCallStatus.Executing
+ ) {
+ return { ...tool, status: ToolCallStatus.Canceled };
+ }
+ return tool;
+ },
+ );
+ const pendingHistoryItem = pendingHistoryItemRef.current;
+ pendingHistoryItem.tools = updatedTools;
+ addItem(pendingHistoryItem, userMessageTimestamp);
+ } else {
+ addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+ }
setPendingHistoryItem(null);
}
addItem(
@@ -412,6 +432,59 @@ export const useGeminiStream = (
}
if (outcome === ToolConfirmationOutcome.Cancel) {
+ declineToolExecution(
+ 'User rejected function call.',
+ ToolCallStatus.Error,
+ );
+ } else {
+ const tool = toolRegistry.getTool(request.name);
+ if (!tool) {
+ throw new Error(
+ `Tool "${request.name}" not found or is not registered.`,
+ );
+ }
+
+ try {
+ abortControllerRef.current = new AbortController();
+ const result = await tool.execute(
+ request.args,
+ abortControllerRef.current.signal,
+ );
+
+ if (abortControllerRef.current.signal.aborted) {
+ declineToolExecution(
+ result.llmContent,
+ ToolCallStatus.Canceled,
+ );
+ return;
+ }
+
+ const functionResponse: Part = {
+ functionResponse: {
+ name: request.name,
+ id: request.callId,
+ response: { output: result.llmContent },
+ },
+ };
+
+ const responseInfo: ToolCallResponseInfo = {
+ callId: request.callId,
+ responsePart: functionResponse,
+ resultDisplay: result.returnDisplay,
+ error: undefined,
+ };
+ updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
+ setStreamingState(StreamingState.Idle);
+ await submitQuery(functionResponse);
+ } finally {
+ abortControllerRef.current = null;
+ }
+ }
+
+ function declineToolExecution(
+ declineMessage: string,
+ status: ToolCallStatus,
+ ) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = {
@@ -426,43 +499,19 @@ export const useGeminiStream = (
functionResponse: {
id: request.callId,
name: request.name,
- response: { error: 'User rejected function call.' },
+ response: { error: declineMessage },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
- error: new Error('User rejected function call.'),
- };
- // Update UI to show cancellation/error
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
- setStreamingState(StreamingState.Idle);
- } else {
- const tool = toolRegistry.getTool(request.name);
- if (!tool) {
- throw new Error(
- `Tool "${request.name}" not found or is not registered.`,
- );
- }
- const result = await tool.execute(request.args);
- const functionResponse: Part = {
- functionResponse: {
- name: request.name,
- id: request.callId,
- response: { output: result.llmContent },
- },
+ error: new Error(declineMessage),
};
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay: result.returnDisplay,
- error: undefined,
- };
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
+ // Update UI to show cancellation/error
+ updateFunctionResponseUI(responseInfo, status);
setStreamingState(StreamingState.Idle);
- await submitQuery(functionResponse);
}
};
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts
index 904e944c..46af465a 100644
--- a/packages/server/src/core/client.ts
+++ b/packages/server/src/core/client.ts
@@ -64,10 +64,13 @@ export class GeminiClient {
.getTool('read_many_files') as ReadManyFilesTool;
if (readManyFilesTool) {
// Read all files in the target directory
- const result = await readManyFilesTool.execute({
- paths: ['**/*'], // Read everything recursively
- useDefaultExcludes: true, // Use default excludes
- });
+ const result = await readManyFilesTool.execute(
+ {
+ paths: ['**/*'], // Read everything recursively
+ useDefaultExcludes: true, // Use default excludes
+ },
+ AbortSignal.timeout(30000),
+ );
if (result.llmContent) {
initialParts.push({
text: `\n--- Full File Context ---\n${result.llmContent}`,
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index 7d8bf7b6..62219938 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -36,7 +36,10 @@ export interface ServerTool {
name: string;
schema: FunctionDeclaration;
// The execute method signature might differ slightly or be wrapped
- execute(params: Record<string, unknown>): Promise<ToolResult>;
+ execute(
+ params: Record<string, unknown>,
+ signal?: AbortSignal,
+ ): Promise<ToolResult>;
shouldConfirmExecute(
params: Record<string, unknown>,
): Promise<ToolCallConfirmationDetails | false>;
@@ -153,7 +156,7 @@ export class Turn {
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
}
- const result = await tool.execute(pendingToolCall.args);
+ const result = await tool.execute(pendingToolCall.args, signal);
return {
...pendingToolCall,
result,
@@ -199,7 +202,11 @@ export class Turn {
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
- yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
+
+ // If aborted we're already yielding the user cancellations elsewhere.
+ if (!signal?.aborted) {
+ yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
+ }
}
}
diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts
index c40b9e44..fd57d97d 100644
--- a/packages/server/src/tools/edit.ts
+++ b/packages/server/src/tools/edit.ts
@@ -333,7 +333,10 @@ Expectation for parameters:
* @param params Parameters for the edit operation
* @returns Result of the edit operation
*/
- async execute(params: EditToolParams): Promise<ToolResult> {
+ async execute(
+ params: EditToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/glob.ts b/packages/server/src/tools/glob.ts
index 9e7df0e8..b1b9d0cf 100644
--- a/packages/server/src/tools/glob.ts
+++ b/packages/server/src/tools/glob.ts
@@ -138,7 +138,10 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
/**
* Executes the glob search with the given parameters
*/
- async execute(params: GlobToolParams): Promise<ToolResult> {
+ async execute(
+ params: GlobToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/grep.ts b/packages/server/src/tools/grep.ts
index e3253ecf..54391832 100644
--- a/packages/server/src/tools/grep.ts
+++ b/packages/server/src/tools/grep.ts
@@ -166,7 +166,10 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
* @param params Parameters for the grep search
* @returns Result of the grep search
*/
- async execute(params: GrepToolParams): Promise<ToolResult> {
+ async execute(
+ params: GrepToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
console.error(
diff --git a/packages/server/src/tools/ls.ts b/packages/server/src/tools/ls.ts
index 01da5121..fea95187 100644
--- a/packages/server/src/tools/ls.ts
+++ b/packages/server/src/tools/ls.ts
@@ -184,7 +184,10 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
* @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
- async execute(params: LSToolParams): Promise<ToolResult> {
+ async execute(
+ params: LSToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return this.errorResult(
diff --git a/packages/server/src/tools/read-file.ts b/packages/server/src/tools/read-file.ts
index 598b4691..de09161d 100644
--- a/packages/server/src/tools/read-file.ts
+++ b/packages/server/src/tools/read-file.ts
@@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
* @param params Parameters for the file reading
* @returns Result with file contents
*/
- async execute(params: ReadFileToolParams): Promise<ToolResult> {
+ async execute(
+ params: ReadFileToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/read-many-files.ts b/packages/server/src/tools/read-many-files.ts
index 0b4b090d..44882e44 100644
--- a/packages/server/src/tools/read-many-files.ts
+++ b/packages/server/src/tools/read-many-files.ts
@@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
}
- async execute(params: ReadManyFilesParams): Promise<ToolResult> {
+ async execute(
+ params: ReadManyFilesParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/shell.ts b/packages/server/src/tools/shell.ts
index fd8a6b1a..7851b76a 100644
--- a/packages/server/src/tools/shell.ts
+++ b/packages/server/src/tools/shell.ts
@@ -118,7 +118,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
return confirmationDetails;
}
- async execute(params: ShellToolParams): Promise<ToolResult> {
+ async execute(
+ params: ShellToolParams,
+ abortSignal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
@@ -174,18 +177,38 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
});
let code: number | null = null;
- let signal: NodeJS.Signals | null = null;
- shell.on(
- 'close',
- (_code: number | null, _signal: NodeJS.Signals | null) => {
- code = _code;
- signal = _signal;
- },
- );
+ let processSignal: NodeJS.Signals | null = null;
+ const closeHandler = (
+ _code: number | null,
+ _signal: NodeJS.Signals | null,
+ ) => {
+ code = _code;
+ processSignal = _signal;
+ };
+ shell.on('close', closeHandler);
+
+ const abortHandler = () => {
+ if (shell.pid) {
+ try {
+ // Kill the entire process group
+ process.kill(-shell.pid, 'SIGTERM');
+ } catch (_e) {
+ // Fallback to killing the main process if group kill fails
+ try {
+ shell.kill('SIGKILL'); // or 'SIGTERM'
+ } catch (_killError) {
+ // Ignore errors if the process is already dead
+ }
+ }
+ }
+ };
+ abortSignal.addEventListener('abort', abortHandler);
// wait for the shell to exit
await new Promise((resolve) => shell.on('close', resolve));
+ abortSignal.removeEventListener('abort', abortHandler);
+
// parse pids (pgrep output) from temporary file and remove it
const backgroundPIDs: number[] = [];
if (fs.existsSync(tempFilePath)) {
@@ -205,19 +228,26 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}
fs.unlinkSync(tempFilePath);
} else {
- console.error('missing pgrep output');
+ if (!abortSignal.aborted) {
+ console.error('missing pgrep output');
+ }
}
- const llmContent = [
- `Command: ${params.command}`,
- `Directory: ${params.directory || '(root)'}`,
- `Stdout: ${stdout || '(empty)'}`,
- `Stderr: ${stderr || '(empty)'}`,
- `Error: ${error ?? '(none)'}`,
- `Exit Code: ${code ?? '(none)'}`,
- `Signal: ${signal ?? '(none)'}`,
- `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
- ].join('\n');
+ let llmContent = '';
+ if (abortSignal.aborted) {
+ llmContent = 'Command did not complete, it was cancelled by the user';
+ } else {
+ llmContent = [
+ `Command: ${params.command}`,
+ `Directory: ${params.directory || '(root)'}`,
+ `Stdout: ${stdout || '(empty)'}`,
+ `Stderr: ${stderr || '(empty)'}`,
+ `Error: ${error ?? '(none)'}`,
+ `Exit Code: ${code ?? '(none)'}`,
+ `Signal: ${processSignal ?? '(none)'}`,
+ `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
+ ].join('\n');
+ }
const returnDisplay = this.config.getDebugMode() ? llmContent : output;
diff --git a/packages/server/src/tools/terminal.ts b/packages/server/src/tools/terminal.ts
index 7320cfb2..af558fb0 100644
--- a/packages/server/src/tools/terminal.ts
+++ b/packages/server/src/tools/terminal.ts
@@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es
return confirmationDetails;
}
- async execute(params: TerminalToolParams): Promise<ToolResult> {
+ async execute(
+ params: TerminalToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts
index ac04450d..7bb05a95 100644
--- a/packages/server/src/tools/tools.ts
+++ b/packages/server/src/tools/tools.ts
@@ -64,7 +64,7 @@ export interface Tool<
* @param params Parameters for the tool execution
* @returns Result of the tool execution
*/
- execute(params: TParams): Promise<TResult>;
+ execute(params: TParams, signal: AbortSignal): Promise<TResult>;
}
/**
@@ -141,9 +141,10 @@ export abstract class BaseTool<
* Abstract method to execute the tool with the given parameters
* Must be implemented by derived classes
* @param params Parameters for the tool execution
+ * @param signal AbortSignal for tool cancellation
* @returns Result of the tool execution
*/
- abstract execute(params: TParams): Promise<TResult>;
+ abstract execute(params: TParams, signal: AbortSignal): Promise<TResult>;
}
export interface ToolResult {
diff --git a/packages/server/src/tools/web-fetch.ts b/packages/server/src/tools/web-fetch.ts
index 12584231..62ca2162 100644
--- a/packages/server/src/tools/web-fetch.ts
+++ b/packages/server/src/tools/web-fetch.ts
@@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
return `Fetching content from ${displayUrl}`;
}
- async execute(params: WebFetchToolParams): Promise<ToolResult> {
+ async execute(
+ params: WebFetchToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts
index c9a47296..1f4c0d94 100644
--- a/packages/server/src/tools/write-file.ts
+++ b/packages/server/src/tools/write-file.ts
@@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
return confirmationDetails;
}
- async execute(params: WriteFileToolParams): Promise<ToolResult> {
+ async execute(
+ params: WriteFileToolParams,
+ _signal: AbortSignal,
+ ): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {