diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts
index f2892a4fa6..0f40fab8d0 100644
--- a/js/genkit/src/genkit.ts
+++ b/js/genkit/src/genkit.ts
@@ -945,11 +945,16 @@ export class Genkit implements HasRegistry {
       },
       async listActions() {
         if (typeof plugin.list === 'function') {
-          return (await plugin.list()).map((a) => ({
-            ...a,
-            // Apply namespace for v2 plugins.
-            name: `${plugin.name}/${a.name}`,
-          }));
+          return (await plugin.list()).map((a) => {
+            if (a.name.startsWith(`${plugin.name}/`)) {
+              return a;
+            }
+            return {
+              ...a,
+              // Apply namespace for v2 plugins.
+              name: `${plugin.name}/${a.name}`,
+            };
+          });
         }
         return [];
       },
diff --git a/js/plugins/compat-oai/src/audio.ts b/js/plugins/compat-oai/src/audio.ts
index 041491b793..149054242d 100644
--- a/js/plugins/compat-oai/src/audio.ts
+++ b/js/plugins/compat-oai/src/audio.ts
@@ -17,11 +17,11 @@
 import type {
   GenerateRequest,
   GenerateResponseData,
-  Genkit,
   ModelReference,
 } from 'genkit';
 import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
 import type { ModelAction, ModelInfo } from 'genkit/model';
+import { model } from 'genkit/plugin';
 import type OpenAI from 'openai';
 import { Response } from 'openai/core.mjs';
 import type {
@@ -181,19 +181,17 @@ async function toGenerateResponse(
 export function defineCompatOpenAISpeechModel<
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
 >(params: {
-  ai: Genkit;
   name: string;
   client: OpenAI;
   modelRef?: ModelReference<CustomOptions>;
   requestBuilder?: SpeechRequestBuilder;
 }): ModelAction<CustomOptions> {
-  const { ai, name, client, modelRef, requestBuilder } = params;
+  const { name, client, modelRef, requestBuilder } = params;
   const modelName = name.substring(name.indexOf('/') + 1);

-  return ai.defineModel(
+  return model(
     {
       name,
-      apiVersion: 'v2',
       ...modelRef?.info,
       configSchema: modelRef?.configSchema,
     },
@@ -335,18 +333,16 @@ function transcriptionToGenerateResponse(
 export function defineCompatOpenAITranscriptionModel<
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
 >(params: {
-  ai: Genkit;
   name: string;
   client: OpenAI;
   modelRef?: ModelReference<CustomOptions>;
   requestBuilder?: TranscriptionRequestBuilder;
 }): ModelAction<CustomOptions> {
-  const { ai, name, client, modelRef, requestBuilder } = params;
+  const { name, client, modelRef, requestBuilder } = params;

-  return ai.defineModel(
+  return model(
     {
       name,
-      apiVersion: 'v2',
       ...modelRef?.info,
       configSchema: modelRef?.configSchema,
     },
diff --git a/js/plugins/compat-oai/src/deepseek/index.ts b/js/plugins/compat-oai/src/deepseek/index.ts
index ae68f90887..126af7a58e 100644
--- a/js/plugins/compat-oai/src/deepseek/index.ts
+++ b/js/plugins/compat-oai/src/deepseek/index.ts
@@ -16,14 +16,13 @@
 import {
   ActionMetadata,
-  Genkit,
   GenkitError,
   modelActionMetadata,
   ModelReference,
   z,
 } from 'genkit';
 import { logger } from 'genkit/logging';
-import { GenkitPlugin } from 'genkit/plugin';
+import { GenkitPluginV2 } from 'genkit/plugin';
 import { ActionType } from 'genkit/registry';
 import OpenAI from 'openai';
 import { openAICompatible, PluginOptions } from '../index.js';
@@ -38,17 +37,15 @@ import {
 export type DeepSeekPluginOptions = Omit;

 const resolver = async (
-  ai: Genkit,
   client: OpenAI,
   actionType: ActionType,
   actionName: string
 ) => {
   if (actionType === 'model') {
     const modelRef = deepSeekModelRef({
-      name: `deepseek/${actionName}`,
+      name: actionName,
     });
-    defineCompatOpenAIModel({
-      ai,
+    return defineCompatOpenAIModel({
       name: modelRef.name,
       client,
       modelRef,
@@ -56,6 +53,7 @@
     });
   } else {
     logger.warn('Only model actions are supported by the DeepSeek plugin');
+    return undefined;
   }
 };
@@ -67,7 +65,7 @@ const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
       const modelRef =
         SUPPORTED_DEEPSEEK_MODELS[model.id] ??
         deepSeekModelRef({
-          name: `deepseek/${model.id}`,
+          name: model.id,
         });
       return modelActionMetadata({
         name: modelRef.name,
@@ -78,7 +76,9 @@
   );
 };

-export function deepSeekPlugin(options?: DeepSeekPluginOptions): GenkitPlugin {
+export function deepSeekPlugin(
+  options?: DeepSeekPluginOptions
+): GenkitPluginV2 {
   const apiKey = options?.apiKey ?? process.env.DEEPSEEK_API_KEY;
   if (!apiKey) {
     throw new GenkitError({
@@ -92,10 +92,9 @@ export function deepSeekPlugin(options?: DeepSeekPluginOptions): GenkitPlugin {
     baseURL: 'https://api.deepseek.com',
     apiKey,
     ...options,
-    initializer: async (ai, client) => {
-      Object.values(SUPPORTED_DEEPSEEK_MODELS).forEach((modelRef) =>
+    initializer: async (client) => {
+      return Object.values(SUPPORTED_DEEPSEEK_MODELS).map((modelRef) =>
         defineCompatOpenAIModel({
-          ai,
           name: modelRef.name,
           client,
           modelRef,
@@ -109,7 +108,7 @@
 }

 export type DeepSeekPlugin = {
-  (params?: DeepSeekPluginOptions): GenkitPlugin;
+  (params?: DeepSeekPluginOptions): GenkitPluginV2;
   model(
     name: keyof typeof SUPPORTED_DEEPSEEK_MODELS,
     config?: z.infer
@@ -119,7 +118,7 @@
 const model = ((name: string, config?: any): ModelReference => {
   return deepSeekModelRef({
-    name: `deepseek/${name}`,
+    name,
     config,
   });
 }) as DeepSeekPlugin['model'];
diff --git a/js/plugins/compat-oai/src/embedder.ts b/js/plugins/compat-oai/src/embedder.ts
index 2fab7e86fe..f4499eccad 100644
--- a/js/plugins/compat-oai/src/embedder.ts
+++ b/js/plugins/compat-oai/src/embedder.ts
@@ -17,7 +17,8 @@
 // import { defineEmbedder, embedderRef } from '@genkit-ai/ai/embedder';
-import type { EmbedderAction, EmbedderReference, Genkit } from 'genkit';
+import type { EmbedderAction, EmbedderReference } from 'genkit';
+import { embedder } from 'genkit/plugin';
 import OpenAI from 'openai';

 /**
@@ -34,25 +35,24 @@ import OpenAI from 'openai';
  * @returns the created {@link EmbedderAction}
  */
 export function defineCompatOpenAIEmbedder(params: {
-  ai: Genkit;
   name: string;
   client: OpenAI;
   embedderRef?: EmbedderReference;
 }): EmbedderAction {
-  const { ai, name, client, embedderRef } = params;
+  const { name, client, embedderRef } = params;
   const modelName = name.substring(name.indexOf('/') + 1);

-  return ai.defineEmbedder(
+  return embedder(
     {
       name,
       configSchema: embedderRef?.configSchema,
       ...embedderRef?.info,
     },
-    async (input, options) => {
-      const { encodingFormat: encoding_format, ...restOfConfig } = options;
+    async (req) => {
+      const { encodingFormat: encoding_format, ...restOfConfig } = req.options;
       const embeddings = await client.embeddings.create({
         model: modelName!,
-        input: input.map((d) => d.text),
+        input: req.input.map((d) => d.text),
         encoding_format,
         ...restOfConfig,
       });
diff --git a/js/plugins/compat-oai/src/image.ts b/js/plugins/compat-oai/src/image.ts
index ee195eac45..1754e5e0b0 100644
--- a/js/plugins/compat-oai/src/image.ts
+++ b/js/plugins/compat-oai/src/image.ts
@@ -17,11 +17,11 @@
 import type {
   GenerateRequest,
   GenerateResponseData,
-  Genkit,
   ModelReference,
 } from 'genkit';
 import { Message, modelRef, z } from 'genkit';
 import { ModelAction, ModelInfo } from 'genkit/model';
+import { model } from 'genkit/plugin';
 import OpenAI from 'openai';
 import type {
   ImageGenerateParams,
@@ -120,20 +120,18 @@ function toGenerateResponse(result: ImagesResponse): GenerateResponseData {
 export function defineCompatOpenAIImageModel<
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
 >(params: {
-  ai: Genkit;
   name: string;
   client: OpenAI;
   modelRef?: ModelReference<CustomOptions>;
   requestBuilder?: ImageRequestBuilder;
 }): ModelAction<CustomOptions> {
-  const { ai, name, client, modelRef, requestBuilder } = params;
+  const { name, client, modelRef, requestBuilder } = params;
   const modelName = name.substring(name.indexOf('/') + 1);

-  return ai.defineModel(
+  return model(
     {
       name,
       ...modelRef?.info,
-      apiVersion: 'v2',
       configSchema: modelRef?.configSchema,
     },
     async (request, { abortSignal }) => {
diff --git a/js/plugins/compat-oai/src/index.ts b/js/plugins/compat-oai/src/index.ts
index e5a40023f2..1c5a5d5ec1 100644
--- a/js/plugins/compat-oai/src/index.ts
+++ b/js/plugins/compat-oai/src/index.ts
@@ -14,10 +14,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-import { ActionMetadata, type Genkit } from 'genkit';
-import { genkitPlugin } from 'genkit/plugin';
+import { ActionMetadata } from 'genkit';
+import { ResolvableAction, genkitPluginV2 } from 'genkit/plugin';
 import { ActionType } from 'genkit/registry';
 import { OpenAI, type ClientOptions } from 'openai';
+import { compatOaiModelRef, defineCompatOpenAIModel } from './model.js';

 export {
   SpeechConfigSchema,
@@ -46,13 +47,12 @@ export {
 export interface PluginOptions extends Partial<ClientOptions> {
   name: string;
-  initializer?: (ai: Genkit, client: OpenAI) => Promise<void>;
+  initializer?: (client: OpenAI) => Promise<ResolvableAction[]>;
   resolver?: (
-    ai: Genkit,
     client: OpenAI,
     actionType: ActionType,
     actionName: string
-  ) => Promise<void>;
+  ) => Promise<ResolvableAction | undefined> | ResolvableAction | undefined;
   listActions?: (client: OpenAI) => Promise<ActionMetadata[]>;
 }
@@ -110,29 +110,41 @@ export interface PluginOptions extends Partial<ClientOptions> {
  */
 export const openAICompatible = (options: PluginOptions) => {
   let listActionsCache;
-  return genkitPlugin(
-    options.name,
-    async (ai: Genkit) => {
-      if (options.initializer) {
-        const client = new OpenAI(options);
-        await options.initializer(ai, client);
+  return genkitPluginV2({
+    name: options.name,
+    async init() {
+      if (!options.initializer) {
+        return [];
       }
+      const client = new OpenAI(options);
+      return await options.initializer(client);
     },
-    async (ai: Genkit, actionType: ActionType, actionName: string) => {
+    async resolve(actionType: ActionType, actionName: string) {
+      const client = new OpenAI(options);
       if (options.resolver) {
-        const client = new OpenAI(options);
-        await options.resolver(ai, client, actionType, actionName);
+        return await options.resolver(client, actionType, actionName);
+      } else {
+        if (actionType === 'model') {
+          return defineCompatOpenAIModel({
+            name: actionName,
+            client,
+            modelRef: compatOaiModelRef({
+              name: actionName,
+            }),
+          });
+        }
+        return undefined;
       }
     },
-    options.listActions
+    list: options.listActions
       ? async () => {
           if (listActionsCache) return listActionsCache;
           const client = new OpenAI(options);
           listActionsCache = await options.listActions!(client);
           return listActionsCache;
         }
-      : undefined
-  );
+      : undefined,
+  });
 };

 export default openAICompatible;
diff --git a/js/plugins/compat-oai/src/model.ts b/js/plugins/compat-oai/src/model.ts
index 4300e630ca..09fde93026 100644
--- a/js/plugins/compat-oai/src/model.ts
+++ b/js/plugins/compat-oai/src/model.ts
@@ -19,7 +19,6 @@ import type {
   GenerateRequest,
   GenerateResponseChunkData,
   GenerateResponseData,
-  Genkit,
   MessageData,
   ModelReference,
   Part,
@@ -29,6 +28,7 @@
 } from 'genkit';
 import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
 import type { ModelAction, ModelInfo, ToolDefinition } from 'genkit/model';
+import { model } from 'genkit/plugin';
 import type OpenAI from 'openai';
 import type {
   ChatCompletion,
@@ -485,19 +485,17 @@ export function openAIModelRunner(
 export function defineCompatOpenAIModel<
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
 >(params: {
-  ai: Genkit;
   name: string;
   client: OpenAI;
   modelRef?: ModelReference<CustomOptions>;
   requestBuilder?: ModelRequestBuilder;
 }): ModelAction<CustomOptions> {
-  const { ai, name, client, modelRef, requestBuilder } = params;
+  const { name, client, modelRef, requestBuilder } = params;
   const modelName = name.substring(name.indexOf('/') + 1);

-  return ai.defineModel(
+  return model(
     {
       name,
-      apiVersion: 'v2',
       ...modelRef?.info,
       configSchema: modelRef?.configSchema,
     },
diff --git a/js/plugins/compat-oai/src/openai/index.ts b/js/plugins/compat-oai/src/openai/index.ts
index 7709fd6f9a..a9358602ee 100644
--- a/js/plugins/compat-oai/src/openai/index.ts
+++ b/js/plugins/compat-oai/src/openai/index.ts
@@ -20,12 +20,11 @@ import {
   embedderActionMetadata,
   embedderRef,
   EmbedderReference,
-  Genkit,
   modelActionMetadata,
   ModelReference,
   z,
 } from 'genkit';
-import { GenkitPlugin } from 'genkit/plugin';
+import { GenkitPluginV2, ResolvableAction } from 'genkit/plugin';
 import { ActionType } from 'genkit/registry';
 import OpenAI from 'openai';
 import {
@@ -66,23 +65,25 @@ export type OpenAIPluginOptions = Omit;
 const UNSUPPORTED_MODEL_MATCHERS = ['babbage', 'davinci', 'codex'];

 const resolver = async (
-  ai: Genkit,
   client: OpenAI,
   actionType: ActionType,
   actionName: string
 ) => {
   if (actionType === 'embedder') {
-    defineCompatOpenAIEmbedder({ ai, name: `openai/${actionName}`, client });
+    return defineCompatOpenAIEmbedder({ name: actionName, client });
   } else if (
     actionName.includes('gpt-image-1') ||
     actionName.includes('dall-e')
   ) {
-    const modelRef = openAIImageModelRef({ name: `openai/${actionName}` });
-    defineCompatOpenAIImageModel({ ai, name: modelRef.name, client, modelRef });
+    const modelRef = openAIImageModelRef({ name: actionName });
+    return defineCompatOpenAIImageModel({
+      name: modelRef.name,
+      client,
+      modelRef,
+    });
   } else if (actionName.includes('tts')) {
-    const modelRef = openAISpeechModelRef({ name: `openai/${actionName}` });
-    defineCompatOpenAISpeechModel({
-      ai,
+    const modelRef = openAISpeechModelRef({ name: actionName });
+    return defineCompatOpenAISpeechModel({
       name: modelRef.name,
       client,
       modelRef,
@@ -92,18 +93,16 @@ const resolver = async (
     actionName.includes('transcribe')
   ) {
     const modelRef = openAITranscriptionModelRef({
-      name: `openai/${actionName}`,
+      name: actionName,
     });
-    defineCompatOpenAITranscriptionModel({
-      ai,
+    return defineCompatOpenAITranscriptionModel({
       name: modelRef.name,
       client,
       modelRef,
     });
   } else {
-    const modelRef = openAIModelRef({ name: `openai/${actionName}` });
-    defineCompatOpenAIModel({
-      ai,
+    const modelRef = openAIModelRef({ name: actionName });
+    return defineCompatOpenAIModel({
       name: modelRef.name,
       client,
       modelRef,
@@ -120,7 +119,7 @@ const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
     response.data.filter(filterOpenAiModels).map((model: OpenAI.Model) => {
       if (model.id.includes('embedding')) {
         return embedderActionMetadata({
-          name: `openai/${model.id}`,
+          name: model.id,
           configSchema: TextEmbeddingConfigSchema,
           info: SUPPORTED_EMBEDDING_MODELS[model.id]?.info,
         });
@@ -130,7 +129,7 @@
       ) {
         const modelRef =
           SUPPORTED_IMAGE_MODELS[model.id] ??
-          openAIImageModelRef({ name: `openai/${model.id}` });
+          openAIImageModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -139,7 +138,7 @@
       } else if (model.id.includes('tts')) {
         const modelRef =
           SUPPORTED_TTS_MODELS[model.id] ??
-          openAISpeechModelRef({ name: `openai/${model.id}` });
+          openAISpeechModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -151,7 +150,7 @@
       ) {
         const modelRef =
           SUPPORTED_STT_MODELS[model.id] ??
-          openAITranscriptionModelRef({ name: `openai/${model.id}` });
+          openAITranscriptionModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -159,8 +158,7 @@
         });
       } else {
         const modelRef =
-          SUPPORTED_GPT_MODELS[model.id] ??
-          openAIModelRef({ name: `openai/${model.id}` });
+          SUPPORTED_GPT_MODELS[model.id] ?? openAIModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -171,49 +169,57 @@
   );
 };

-export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPlugin {
+export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 {
   return openAICompatible({
     name: 'openai',
     ...options,
-    initializer: async (ai, client) => {
-      Object.values(SUPPORTED_GPT_MODELS).forEach((modelRef) =>
-        defineCompatOpenAIModel({ ai, name: modelRef.name, client, modelRef })
+    initializer: async (client) => {
+      const models = [] as ResolvableAction[];
+      models.push(
+        ...Object.values(SUPPORTED_GPT_MODELS).map((modelRef) =>
+          defineCompatOpenAIModel({ name: modelRef.name, client, modelRef })
+        )
       );
-      Object.values(SUPPORTED_EMBEDDING_MODELS).forEach((embedderRef) =>
-        defineCompatOpenAIEmbedder({
-          ai,
-          name: embedderRef.name,
-          client,
-          embedderRef,
-        })
+      models.push(
+        ...Object.values(SUPPORTED_EMBEDDING_MODELS).map((embedderRef) =>
+          defineCompatOpenAIEmbedder({
+            name: embedderRef.name,
+            client,
+            embedderRef,
+          })
+        )
       );
-      Object.values(SUPPORTED_TTS_MODELS).forEach((modelRef) =>
-        defineCompatOpenAISpeechModel({
-          ai,
-          name: modelRef.name,
-          client,
-          modelRef,
-        })
+      models.push(
+        ...Object.values(SUPPORTED_TTS_MODELS).map((modelRef) =>
+          defineCompatOpenAISpeechModel({
+            name: modelRef.name,
+            client,
+            modelRef,
+          })
+        )
      );
-      Object.values(SUPPORTED_STT_MODELS).forEach((modelRef) =>
-        defineCompatOpenAITranscriptionModel({
-          ai,
-          name: modelRef.name,
-          client,
-          modelRef,
-        })
+      models.push(
+        ...Object.values(SUPPORTED_STT_MODELS).map((modelRef) =>
+          defineCompatOpenAITranscriptionModel({
+            name: modelRef.name,
+            client,
+            modelRef,
+          })
+        )
      );
-      Object.values(SUPPORTED_IMAGE_MODELS).forEach((modelRef) =>
-        defineCompatOpenAIImageModel({
-          ai,
-          name: modelRef.name,
-          client,
-          modelRef,
-          requestBuilder: modelRef.name.includes('gpt-image-1')
-            ? gptImage1RequestBuilder
-            : undefined,
-        })
+      models.push(
+        ...Object.values(SUPPORTED_IMAGE_MODELS).map((modelRef) =>
+          defineCompatOpenAIImageModel({
+            name: modelRef.name,
+            client,
+            modelRef,
+            requestBuilder: modelRef.name.includes('gpt-image-1')
+              ? gptImage1RequestBuilder
+              : undefined,
+          })
+        )
      );
+      return models;
     },
     resolver,
     listActions,
@@ -221,7 +227,7 @@
 }

 export type OpenAIPlugin = {
-  (params?: OpenAIPluginOptions): GenkitPlugin;
+  (params?: OpenAIPluginOptions): GenkitPluginV2;
   model(
     name:
       | keyof typeof SUPPORTED_GPT_MODELS
@@ -263,24 +269,24 @@
 const model = ((name: string, config?: any): ModelReference => {
   if (name.includes('gpt-image-1') || name.includes('dall-e')) {
     return openAIImageModelRef({
-      name: `openai/${name}`,
+      name,
       config,
     });
   }
   if (name.includes('tts')) {
     return openAISpeechModelRef({
-      name: `openai/${name}`,
+      name,
       config,
     });
   }
   if (name.includes('whisper') || name.includes('transcribe')) {
     return openAITranscriptionModelRef({
-      name: `openai/${name}`,
+      name,
       config,
     });
   }
   return openAIModelRef({
-    name: `openai/${name}`,
+    name,
     config,
   });
 }) as OpenAIPlugin['model'];
@@ -290,7 +296,7 @@ const embedder = ((
   name: string,
   config?: any
 ): EmbedderReference => {
   return embedderRef({
-    name: `openai/${name}`,
+    name,
     config,
     configSchema: TextEmbeddingConfigSchema,
   });
diff --git a/js/plugins/compat-oai/src/xai/index.ts b/js/plugins/compat-oai/src/xai/index.ts
index 3bd72534c4..8a4bf6a4b3 100644
--- a/js/plugins/compat-oai/src/xai/index.ts
+++ b/js/plugins/compat-oai/src/xai/index.ts
@@ -16,14 +16,13 @@
 import {
   ActionMetadata,
-  Genkit,
   GenkitError,
   modelActionMetadata,
   ModelReference,
   z,
 } from 'genkit';
 import { logger } from 'genkit/logging';
-import { GenkitPlugin } from 'genkit/plugin';
+import { GenkitPluginV2, ResolvableAction } from 'genkit/plugin';
 import { ActionType } from 'genkit/registry';
 import OpenAI from 'openai';
 import {
@@ -43,15 +42,13 @@
 export type XAIPluginOptions = Omit;

 const resolver = async (
-  ai: Genkit,
   client: OpenAI,
   actionType: ActionType,
   actionName: string
 ) => {
   if (actionType === 'model') {
-    const modelRef = xaiModelRef({ name: `xai/${actionName}` });
-    defineCompatOpenAIModel({
-      ai,
+    const modelRef = xaiModelRef({ name: actionName });
+    return defineCompatOpenAIModel({
       name: modelRef.name,
       client,
       modelRef,
@@ -60,6 +57,7 @@
   } else {
     logger.warn('Only model actions are supported by the XAI plugin');
   }
+  return undefined;
 };

 const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
@@ -70,7 +68,7 @@
       if (model.id.includes('image')) {
         const modelRef =
           SUPPORTED_IMAGE_MODELS[model.id] ??
-          xaiImageModelRef({ name: `xai/${model.id}` });
+          xaiImageModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -79,7 +77,7 @@
       } else {
         const modelRef =
           SUPPORTED_LANGUAGE_MODELS[model.id] ??
-          xaiModelRef({ name: `xai/${model.id}` });
+          xaiModelRef({ name: model.id });
         return modelActionMetadata({
           name: modelRef.name,
           info: modelRef.info,
@@ -90,7 +88,7 @@
   );
 };

-export function xAIPlugin(options?: XAIPluginOptions): GenkitPlugin {
+export function xAIPlugin(options?: XAIPluginOptions): GenkitPluginV2 {
   const apiKey = options?.apiKey ?? process.env.XAI_API_KEY;
   if (!apiKey) {
     throw new GenkitError({
@@ -104,24 +102,28 @@
     baseURL: 'https://api.x.ai/v1',
     apiKey,
     ...options,
-    initializer: async (ai, client) => {
-      Object.values(SUPPORTED_LANGUAGE_MODELS).forEach((modelRef) =>
-        defineCompatOpenAIModel({
-          ai,
-          name: modelRef.name,
-          client,
-          modelRef,
-          requestBuilder: grokRequestBuilder,
-        })
+    initializer: async (client) => {
+      const models = [] as ResolvableAction[];
+      models.push(
+        ...Object.values(SUPPORTED_LANGUAGE_MODELS).map((modelRef) =>
+          defineCompatOpenAIModel({
+            name: modelRef.name,
+            client,
+            modelRef,
+            requestBuilder: grokRequestBuilder,
+          })
+        )
      );
-      Object.values(SUPPORTED_IMAGE_MODELS).forEach((modelRef) =>
-        defineCompatOpenAIImageModel({
-          ai,
-          name: modelRef.name,
-          client,
-          modelRef,
-        })
+      models.push(
+        ...Object.values(SUPPORTED_IMAGE_MODELS).map((modelRef) =>
+          defineCompatOpenAIImageModel({
+            name: modelRef.name,
+            client,
+            modelRef,
+          })
+        )
      );
+      return models;
     },
     resolver,
     listActions,
@@ -129,7 +131,7 @@
 }

 export type XAIPlugin = {
-  (params?: XAIPluginOptions): GenkitPlugin;
+  (params?: XAIPluginOptions): GenkitPluginV2;
   model(
     name: keyof typeof SUPPORTED_LANGUAGE_MODELS,
     config?: z.infer
@@ -144,12 +146,12 @@
 const model = ((name: string, config?: any): ModelReference => {
   if (name.includes('image')) {
     return xaiImageModelRef({
-      name: `xai/${name}`,
+      name,
       config,
     });
   }
   return xaiModelRef({
-    name: `xai/${name}`,
+    name,
     config,
   });
 }) as XAIPlugin['model'];
diff --git a/js/plugins/compat-oai/tests/openai_test.ts b/js/plugins/compat-oai/tests/openai_test.ts
index 1653a07ea3..a5ee506fec 100644
--- a/js/plugins/compat-oai/tests/openai_test.ts
+++ b/js/plugins/compat-oai/tests/openai_test.ts
@@ -15,15 +15,8 @@
  * limitations under the License.
  */

-import {
-  afterEach,
-  beforeEach,
-  describe,
-  expect,
-  it,
-  jest,
-} from '@jest/globals';
-import { modelRef, type GenerateRequest, type Genkit } from 'genkit';
+import { afterEach, describe, expect, it, jest } from '@jest/globals';
+import { modelRef, type GenerateRequest } from 'genkit/model';
 import type OpenAI from 'openai';

 import {
@@ -32,116 +25,89 @@ import {
   toOpenAIRequestBody,
 } from '../src/model';

-jest.mock('@genkit-ai/ai/model', () => ({
-  ...(jest.requireActual('@genkit-ai/ai/model') as Record),
-  defineModel: jest.fn(),
-}));
-
 describe('gptModel', () => {
-  let ai: Genkit;
-
-  beforeEach(() => {
-    ai = {
-      defineModel: jest.fn(),
-    } as unknown as Genkit;
-  });
-
   afterEach(() => {
     jest.clearAllMocks();
   });

   it('should correctly define supported GPT models', () => {
-    jest.spyOn(ai, 'defineModel').mockImplementation((() => ({})) as any);
-    defineCompatOpenAIModel({
-      ai,
+    const model = defineCompatOpenAIModel({
       name: 'openai/gpt-4o',
       client: {} as OpenAI,
       modelRef: testModelRef('openai/gpt-4o'),
     });
-    expect(ai.defineModel).toHaveBeenCalledWith(
-      {
-        name: 'openai/gpt-4o',
-        supports: {
-          multiturn: true,
-          tools: true,
-          media: true,
-          systemRole: true,
-          output: ['text', 'json'],
-        },
-        configSchema: ChatCompletionCommonConfigSchema,
-        apiVersion: 'v2',
+    expect({
+      name: model.__action.name,
+      supports: model.__action.metadata?.model.supports,
+    }).toStrictEqual({
+      name: 'openai/gpt-4o',
+      supports: {
+        multiturn: true,
+        tools: true,
+        media: true,
+        systemRole: true,
+        output: ['text', 'json'],
       },
-      expect.any(Function)
-    );
+    });
   });

   it('should correctly define gpt-4.1, gpt-4.1-mini, and gpt-4.1-nano', () => {
-    jest.spyOn(ai, 'defineModel').mockImplementation((() => ({})) as any);
-    defineCompatOpenAIModel({
-      ai,
+    const gpt41 = defineCompatOpenAIModel({
       name: 'openai/gpt-4.1',
       client: {} as OpenAI,
       modelRef: testModelRef('openai/gpt-4.1'),
     });
-    expect(ai.defineModel).toHaveBeenCalledWith(
-      {
-        name: 'openai/gpt-4.1',
-        supports: {
-          multiturn: true,
-          tools: true,
-          media: true,
-          systemRole: true,
-          output: ['text', 'json'],
-        },
-        configSchema: ChatCompletionCommonConfigSchema,
-        apiVersion: 'v2',
+    expect({
+      name: gpt41.__action.name,
+      supports: gpt41.__action.metadata?.model.supports,
+    }).toStrictEqual({
+      name: 'openai/gpt-4.1',
+      supports: {
+        multiturn: true,
+        tools: true,
+        media: true,
+        systemRole: true,
+        output: ['text', 'json'],
      },
-      expect.any(Function)
-    );
+    });

-    defineCompatOpenAIModel({
-      ai,
+    const gpt41mini = defineCompatOpenAIModel({
       name: 'openai/gpt-4.1-mini',
       client: {} as OpenAI,
       modelRef: testModelRef('openai/gpt-4.1-mini'),
     });
-    expect(ai.defineModel).toHaveBeenCalledWith(
-      {
-        name: 'openai/gpt-4.1-mini',
-        supports: {
-          multiturn: true,
-          tools: true,
-          media: true,
-          systemRole: true,
-          output: ['text', 'json'],
-        },
-        configSchema: ChatCompletionCommonConfigSchema,
-        apiVersion: 'v2',
+    expect({
+      name: gpt41mini.__action.name,
+      supports: gpt41mini.__action.metadata?.model.supports,
+    }).toStrictEqual({
+      name: 'openai/gpt-4.1-mini',
+      supports: {
+        multiturn: true,
+        tools: true,
+        media: true,
+        systemRole: true,
+        output: ['text', 'json'],
      },
-      expect.any(Function)
-    );
+    });

-    defineCompatOpenAIModel({
-      ai,
+    const gpt41nano = defineCompatOpenAIModel({
       name: 'openai/gpt-4.1-nano',
       client: {} as OpenAI,
       modelRef: testModelRef('openai/gpt-4.1-nano'),
     });
-    expect(ai.defineModel).toHaveBeenCalledWith(
-      {
-        name: 'openai/gpt-4.1-nano',
-        supports: {
-          multiturn: true,
-          tools: true,
-          media: true,
-          systemRole: true,
-          output: ['text', 'json'],
-        },
-        configSchema: ChatCompletionCommonConfigSchema,
-        apiVersion: 'v2',
+    expect({
+      name: gpt41nano.__action.name,
+      supports: gpt41nano.__action.metadata?.model.supports,
+    }).toStrictEqual({
+      name: 'openai/gpt-4.1-nano',
+      supports: {
+        multiturn: true,
+        tools: true,
+        media: true,
+        systemRole: true,
+        output: ['text', 'json'],
      },
-      expect.any(Function)
-    );
+    });
   });
 });
diff --git a/js/testapps/compat-oai/src/index.ts b/js/testapps/compat-oai/src/index.ts
index 00f65290e5..10fde22780 100644
--- a/js/testapps/compat-oai/src/index.ts
+++ b/js/testapps/compat-oai/src/index.ts
@@ -40,17 +40,16 @@ const ai = genkit({
       name: 'openrouter',
       baseURL: 'https://openrouter.ai/api/v1',
       apiKey: process.env['OPENROUTER_API_KEY'],
-      initializer: async (ai, client) => {
-        for (const model of DECLARED_MODELS) {
+      async initializer(client) {
+        return DECLARED_MODELS.map((model) =>
           defineCompatOpenAIModel({
-            ai,
             name: `openrouter/${model}`,
             client,
             modelRef: compatOaiModelRef({
               name: `openrouter/${model}`,
             }),
-          });
-        }
+          })
+        );
       },
     }),
   ],