Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,12 @@ function definePromptAsync<
...renderOptions?.config,
},
});

// Fix for issue #3348: Preserve AbortSignal object
// AbortSignal needs its prototype chain and shouldn't be processed by stripUndefinedProps
if (renderOptions?.abortSignal) {
opts.abortSignal = renderOptions.abortSignal;
}
// if config is empty and it was not explicitly passed in, we delete it, don't want {}
if (Object.keys(opts.config).length === 0 && !renderOptions?.config) {
delete opts.config;
Expand Down
64 changes: 64 additions & 0 deletions js/ai/tests/prompt/prompt_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { Document } from '../../src/document.js';
import type { GenerateOptions } from '../../src/index.js';
import { defineModel } from '../../src/model.js';
import {
definePrompt,
type PromptConfig,
Expand All @@ -45,6 +46,26 @@ describe('prompt', () => {
},
async () => 'a'
);

// Define a special model to test AbortSignal preservation
defineModel(
registry,
{
name: 'abortTestModel',
apiVersion: 'v2',
},
async (request, opts) => {
// Store the abortSignal for verification
(defineModel as any).__test__lastAbortSignal = request.abortSignal;
return {
message: {
role: 'model',
content: [{ text: 'AbortSignal preserved correctly' }],
},
finishReason: 'stop',
};
}
);
});

let basicTests: {
Expand Down Expand Up @@ -821,6 +842,49 @@ describe('prompt', () => {
toJsonSchema({ schema: schema1 })
);
});

it('preserves AbortSignal objects through the rendering pipeline', async () => {
// Test AbortSignal.timeout()
const timeoutSignal = AbortSignal.timeout(1000);
const prompt = definePrompt(registry, {
name: 'abortTestPrompt',
model: 'abortTestModel',
prompt: 'test message',
});

const rendered = await prompt.render(undefined, {
abortSignal: timeoutSignal,
});

// Verify the AbortSignal is preserved in the rendered options
assert.ok(rendered.abortSignal, 'AbortSignal should be preserved');
assert.strictEqual(
rendered.abortSignal,
timeoutSignal,
'Should be the exact same AbortSignal instance'
);
assert.ok(
rendered.abortSignal instanceof AbortSignal,
'Should be an AbortSignal instance'
);

// Test manual AbortController
const controller = new AbortController();
const rendered2 = await prompt.render(undefined, {
abortSignal: controller.signal,
});

assert.ok(rendered2.abortSignal, 'Manual AbortSignal should be preserved');
assert.strictEqual(
rendered2.abortSignal,
controller.signal,
'Should be the exact same manual AbortSignal instance'
);
assert.ok(
rendered2.abortSignal instanceof AbortSignal,
'Manual AbortSignal should be an AbortSignal instance'
);
});
});

function stripUndefined(input: any) {
Expand Down