diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts
index 65292e1f2b..ddf99e3361 100644
--- a/genkit-tools/common/src/eval/evaluate.ts
+++ b/genkit-tools/common/src/eval/evaluate.ts
@@ -15,10 +15,13 @@
  */
 
 import { randomUUID } from 'crypto';
+import { z } from 'zod';
 import { getDatasetStore, getEvalStore } from '.';
 import type { RuntimeManager } from '../manager/manager';
 import {
   DatasetSchema,
+  GenerateActionOptions,
+  GenerateActionOptionsSchema,
   GenerateResponseSchema,
   type Action,
   type CandidateData,
@@ -33,6 +36,7 @@
 import {
   evaluatorName,
   generateTestCaseId,
+  getAction,
   getEvalExtractors,
   getModelInput,
   hasAction,
@@ -50,7 +54,7 @@ interface InferenceRunState {
   testCaseId: string;
   input: any;
   reference?: any;
-  traceId?: string;
+  traceIds: string[];
   response?: any;
   evalError?: string;
 }
@@ -61,8 +65,8 @@ interface FullInferenceSample {
   reference?: any;
 }
 
-const SUPPORTED_ACTION_TYPES = ['flow', 'model'] as const;
-
+const SUPPORTED_ACTION_TYPES = ['flow', 'model', 'executable-prompt'] as const;
+type SupportedActionType = (typeof SUPPORTED_ACTION_TYPES)[number];
 /**
  * Starts a new evaluation run. Intended to be used via the reflection API.
  */
@@ -253,7 +257,7 @@ async function bulkRunAction(params: {
 }): Promise<EvalInput[]> {
   const { manager, actionRef, inferenceDataset, context, actionConfig } =
     params;
-  const isModelAction = actionRef.startsWith('/model');
+  const actionType = getSupportedActionType(actionRef);
   if (inferenceDataset.length === 0) {
     throw new Error('Cannot run inference, no data provided');
   }
@@ -267,7 +271,7 @@ async function bulkRunAction(params: {
   const evalInputs: EvalInput[] = [];
   for (const sample of fullInferenceDataset) {
     logger.info(`Running inference '${actionRef}' ...`);
-    if (isModelAction) {
+    if (actionType === 'model') {
       states.push(
         await runModelAction({
           manager,
@@ -276,7 +280,7 @@ async function bulkRunAction(params: {
           modelConfig: actionConfig,
         })
       );
-    } else {
+    } else if (actionType === 'flow') {
       states.push(
         await runFlowAction({
           manager,
@@ -285,6 +289,17 @@ async function bulkRunAction(params: {
           context,
         })
       );
+    } else {
+      // executable-prompt action
+      states.push(
+        await runPromptAction({
+          manager,
+          actionRef,
+          sample,
+          context,
+          promptConfig: actionConfig,
+        })
+      );
     }
   }
 
@@ -311,14 +326,16 @@ async function runFlowAction(params: {
     });
     state = {
       ...sample,
-      traceId: runActionResponse.telemetry?.traceId,
+      traceIds: runActionResponse.telemetry?.traceId
+        ? [runActionResponse.telemetry?.traceId]
+        : [],
       response: runActionResponse.result,
     };
   } catch (e: any) {
     const traceId = e?.data?.details?.traceId;
     state = {
       ...sample,
-      traceId,
+      traceIds: traceId ? [traceId] : [],
       evalError: `Error when running inference. Details: ${e?.message ?? e}`,
     };
   }
@@ -341,14 +358,98 @@ async function runModelAction(params: {
     });
     state = {
       ...sample,
-      traceId: runActionResponse.telemetry?.traceId,
+      traceIds: runActionResponse.telemetry?.traceId
+        ? [runActionResponse.telemetry?.traceId]
+        : [],
+      response: runActionResponse.result,
+    };
+  } catch (e: any) {
+    const traceId = e?.data?.details?.traceId;
+    state = {
+      ...sample,
+      traceIds: traceId ? [traceId] : [],
+      evalError: `Error when running inference. Details: ${e?.message ?? e}`,
+    };
+  }
+  return state;
+}
+
+async function runPromptAction(params: {
+  manager: RuntimeManager;
+  actionRef: string;
+  sample: FullInferenceSample;
+  context?: any;
+  promptConfig?: any;
+}): Promise<InferenceRunState> {
+  const { manager, actionRef, sample, context, promptConfig } = { ...params };
+
+  const { model: modelFromConfig, ...restOfConfig } = promptConfig ?? {};
+  const model = await resolveModel({ manager, actionRef, modelFromConfig });
+  if (!model) {
+    throw new Error(
+      'Could not resolve model. Please specify model and try again'
+    );
+  }
+  let state: InferenceRunState;
+  let renderedPrompt: {
+    result: GenerateActionOptions;
+    traceId: string;
+  };
+  // Step 1. Render the prompt with inputs
+  try {
+    const runActionResponse = await manager.runAction({
+      key: actionRef,
+      input: sample.input,
+      context: context ? JSON.parse(context) : undefined,
+    });
+
+    renderedPrompt = {
+      traceId: runActionResponse.telemetry?.traceId!,
+      result: GenerateActionOptionsSchema.parse(runActionResponse.result),
+    };
+  } catch (e: any) {
+    if (e instanceof z.ZodError) {
+      state = {
+        ...sample,
+        traceIds: [],
+        evalError: `Error parsing prompt response. Details: ${JSON.stringify(e.format())}`,
+      };
+    } else {
+      const traceId = e?.data?.details?.traceId;
+      state = {
+        ...sample,
+        traceIds: traceId ? [traceId] : [],
+        evalError: `Error when rendering prompt. Details: ${e?.message ?? e}`,
+      };
+    }
+    return state;
+  }
+  // Step 2. Run rendered prompt on the model
+  try {
+    let modelInput = renderedPrompt.result;
+    if (restOfConfig) {
+      modelInput = { ...modelInput, config: restOfConfig };
+    }
+    const runActionResponse = await manager.runAction({
+      key: model,
+      input: modelInput,
+    });
+    const traceIds = runActionResponse.telemetry?.traceId
+      ? [renderedPrompt.traceId, runActionResponse.telemetry?.traceId]
+      : [renderedPrompt.traceId];
+    state = {
+      ...sample,
+      traceIds: traceIds,
       response: runActionResponse.result,
     };
   } catch (e: any) {
     const traceId = e?.data?.details?.traceId;
+    const traceIds = traceId
+      ? [renderedPrompt.traceId, traceId]
+      : [renderedPrompt.traceId];
     state = {
       ...sample,
-      traceId,
+      traceIds: traceIds.filter((t): t is string => !!t),
       evalError: `Error when running inference. Details: ${e?.message ?? e}`,
     };
   }
@@ -362,25 +463,38 @@ async function gatherEvalInput(params: {
 }): Promise<EvalInput> {
   const { manager, actionRef, state } = params;
 
+  const actionType = getSupportedActionType(actionRef);
   const extractors = await getEvalExtractors(actionRef);
-  const traceId = state.traceId;
-  if (!traceId) {
-    logger.warn('No traceId available...');
+  const traceIds = state.traceIds;
+
+  if (
+    traceIds.length === 0 ||
+    (actionType === 'executable-prompt' && traceIds.length < 2)
+  ) {
+    logger.warn('No valid traceId available...');
     return {
       ...state,
       error: state.evalError,
       testCaseId: state.testCaseId,
-      traceIds: [],
+      traceIds: traceIds,
     };
   }
 
+  // Only the last collected trace to be used for evaluation.
+  const traceId = traceIds.at(-1)!;
   const trace = await manager.getTrace({
     traceId,
   });
 
-  const isModelAction = actionRef.startsWith('/model');
-  // Always use original input for models.
-  const input = isModelAction ? state.input : extractors.input(trace);
+  // Always use original input for models and prompts.
+  const input = actionType === 'flow' ? extractors.input(trace) : state.input;
+  let custom = undefined;
+  if (actionType === 'executable-prompt') {
+    const promptTrace = await manager.getTrace({
+      traceId: traceIds[0],
+    });
+    custom = { renderedPrompt: extractors.output(promptTrace) };
+  }
 
   const nestedSpan = stackTraceSpans(trace);
   if (!nestedSpan) {
@@ -389,7 +503,8 @@ async function gatherEvalInput(params: {
       input,
       error: `Unable to extract any spans from trace ${traceId}`,
       reference: state.reference,
-      traceIds: [traceId],
+      custom,
+      traceIds: traceIds,
     };
   }
 
@@ -400,13 +515,15 @@ async function gatherEvalInput(params: {
       error:
         getSpanErrorMessage(nestedSpan) ?? `Unknown error in trace ${traceId}`,
       reference: state.reference,
-      traceIds: [traceId],
+      custom,
+      traceIds: traceIds,
     };
   }
 
   const output = extractors.output(trace);
   const context = extractors.context(trace);
-  const error = isModelAction ? getErrorFromModelResponse(output) : undefined;
+  const error =
+    actionType === 'model' ? getErrorFromModelResponse(output) : undefined;
 
   return {
     // TODO Replace this with unified trace class
@@ -416,10 +533,28 @@ async function gatherEvalInput(params: {
     error,
     context: Array.isArray(context) ? context : [context],
     reference: state.reference,
-    traceIds: [traceId],
+    custom,
+    traceIds: traceIds,
   };
 }
 
+async function resolveModel(params: {
+  manager: RuntimeManager;
+  actionRef: string;
+  modelFromConfig?: string;
+}) {
+  const { manager, actionRef, modelFromConfig } = { ...params };
+
+  // Prefer to use modelFromConfig
+  if (modelFromConfig) {
+    return modelFromConfig;
+  }
+
+  const actionData = await getAction({ manager, actionRef });
+  const promptMetadata = actionData?.metadata?.prompt as any;
+  return promptMetadata?.model ? `/model/${promptMetadata?.model}` : undefined;
+}
+
 function getSpanErrorMessage(span: SpanData): string | undefined {
   if (span && span.status?.code === 2 /* SpanStatusCode.ERROR */) {
     // It's possible for a trace to have multiple exception events,
@@ -466,3 +601,16 @@ function isSupportedActionRef(actionRef: string) {
     actionRef.startsWith(`/${supportedType}`)
   );
 }
+
+function getSupportedActionType(actionRef: string): SupportedActionType {
+  if (actionRef.startsWith('/model')) {
+    return 'model';
+  }
+  if (actionRef.startsWith('/flow')) {
+    return 'flow';
+  }
+  if (actionRef.startsWith('/executable-prompt')) {
+    return 'executable-prompt';
+  }
+  throw new Error(`Unsupported action type: ${actionRef}`);
+}
diff --git a/genkit-tools/common/src/types/eval.ts b/genkit-tools/common/src/types/eval.ts
index ac6eddb9dd..56590889ae 100644
--- a/genkit-tools/common/src/types/eval.ts
+++ b/genkit-tools/common/src/types/eval.ts
@@ -134,6 +134,7 @@ export const EvalInputSchema = z.object({
   error: z.string().optional(),
   context: z.array(z.any()).optional(),
   reference: z.any().optional(),
+  custom: z.record(z.string(), z.any()).optional(),
   traceIds: z.array(z.string()),
 });
 export type EvalInput = z.infer<typeof EvalInputSchema>;
@@ -251,7 +252,12 @@ export const DatasetSchemaSchema = z.object({
 });
 
 /** Type of dataset, useful for UI niceties. */
-export const DatasetTypeSchema = z.enum(['UNKNOWN', 'FLOW', 'MODEL']);
+export const DatasetTypeSchema = z.enum([
+  'UNKNOWN',
+  'FLOW',
+  'MODEL',
+  'EXECUTABLE_PROMPT',
+]);
 export type DatasetType = z.infer<typeof DatasetTypeSchema>;
 
 /**
diff --git a/genkit-tools/common/src/utils/eval.ts b/genkit-tools/common/src/utils/eval.ts
index c1944c3e5a..8f0eefe5e6 100644
--- a/genkit-tools/common/src/utils/eval.ts
+++ b/genkit-tools/common/src/utils/eval.ts
@@ -332,6 +332,16 @@ export async function hasAction(params: {
   return actionsRecord.hasOwnProperty(actionRef);
 }
 
+export async function getAction(params: {
+  manager: RuntimeManager;
+  actionRef: string;
+}): Promise<Action | undefined> {
+  const { manager, actionRef } = { ...params };
+  const allActions = await manager.listActions();
+
+  return Object.values(allActions).find((action) => action.key === actionRef);
+}
+
 /** Helper function that maps string data to GenerateRequest */
 export function getModelInput(data: any, modelConfig: any): GenerateRequest {
   let message: MessageData;
@@ -351,7 +361,7 @@ export function getModelInput(data: any, modelConfig: any): GenerateRequest {
   } else {
     const maybeRequest = GenerateRequestSchema.safeParse(data);
     if (maybeRequest.success) {
-      return maybeRequest.data;
+      return { ...maybeRequest.data, config: modelConfig };
     } else {
       throw new Error(
         `Unable to parse model input as MessageSchema. Details: ${maybeRequest.error}`
diff --git a/js/testapps/evals/prompts/hello.prompt b/js/testapps/evals/prompts/hello.prompt
new file mode 100644
index 0000000000..0ef26c9c6d
--- /dev/null
+++ b/js/testapps/evals/prompts/hello.prompt
@@ -0,0 +1,14 @@
+---
+model: googleai/gemini-2.5-flash
+config:
+  temperature: 0.75
+input:
+  schema:
+    query: string
+---
+
+{{role "system"}}
+Only write code, do not explain
+
+{{role "user"}}
+Assist the user with: {{query}}
\ No newline at end of file
diff --git a/js/testapps/evals/src/genkit.ts b/js/testapps/evals/src/genkit.ts
index f03cf816da..e59f0f13b6 100644
--- a/js/testapps/evals/src/genkit.ts
+++ b/js/testapps/evals/src/genkit.ts
@@ -22,11 +22,6 @@ import {
   googleAI,
   textEmbeddingGecko001,
 } from '@genkit-ai/googleai';
-import { vertexAI } from '@genkit-ai/vertexai';
-import {
-  VertexAIEvaluationMetricType,
-  vertexAIEvaluation,
-} from '@genkit-ai/vertexai/evaluation';
 import { genkit } from 'genkit';
 import { langchain } from 'genkitx-langchain';
 
@@ -70,25 +65,6 @@ export const ai = genkit({
         },
       ],
     }),
-    vertexAI({
-      location: 'us-central1',
-    }),
-    vertexAIEvaluation({
-      location: 'us-central1',
-      metrics: [
-        VertexAIEvaluationMetricType.BLEU,
-        VertexAIEvaluationMetricType.GROUNDEDNESS,
-        VertexAIEvaluationMetricType.SAFETY,
-        {
-          type: VertexAIEvaluationMetricType.ROUGE,
-          metricSpec: {
-            rougeType: 'rougeLsum',
-            useStemmer: true,
-            splitSummaries: 'true',
-          },
-        },
-      ],
-    }),
     devLocalVectorstore([
       {
         indexName: 'pdfQA',