190 changes: 169 additions & 21 deletions genkit-tools/common/src/eval/evaluate.ts
@@ -15,10 +15,13 @@
*/

import { randomUUID } from 'crypto';
import { z } from 'zod';
import { getDatasetStore, getEvalStore } from '.';
import type { RuntimeManager } from '../manager/manager';
import {
DatasetSchema,
GenerateActionOptions,
GenerateActionOptionsSchema,
GenerateResponseSchema,
type Action,
type CandidateData,
@@ -33,6 +36,7 @@ import {
import {
evaluatorName,
generateTestCaseId,
getAction,
getEvalExtractors,
getModelInput,
hasAction,
@@ -50,7 +54,7 @@ interface InferenceRunState {
testCaseId: string;
input: any;
reference?: any;
traceId?: string;
traceIds: string[];
response?: any;
evalError?: string;
}
@@ -61,8 +65,8 @@ interface FullInferenceSample {
reference?: any;
}

const SUPPORTED_ACTION_TYPES = ['flow', 'model'] as const;

const SUPPORTED_ACTION_TYPES = ['flow', 'model', 'executable-prompt'] as const;
type SupportedActionType = (typeof SUPPORTED_ACTION_TYPES)[number];
/**
* Starts a new evaluation run. Intended to be used via the reflection API.
*/
@@ -253,7 +257,7 @@ async function bulkRunAction(params: {
}): Promise<EvalInput[]> {
const { manager, actionRef, inferenceDataset, context, actionConfig } =
params;
const isModelAction = actionRef.startsWith('/model');
const actionType = getSupportedActionType(actionRef);
if (inferenceDataset.length === 0) {
throw new Error('Cannot run inference, no data provided');
}
@@ -267,7 +271,7 @@ async function bulkRunAction(params: {
const evalInputs: EvalInput[] = [];
for (const sample of fullInferenceDataset) {
logger.info(`Running inference '${actionRef}' ...`);
if (isModelAction) {
if (actionType === 'model') {
states.push(
await runModelAction({
manager,
@@ -276,7 +280,7 @@ async function bulkRunAction(params: {
modelConfig: actionConfig,
})
);
} else {
} else if (actionType === 'flow') {
states.push(
await runFlowAction({
manager,
Expand All @@ -285,6 +289,17 @@ async function bulkRunAction(params: {
context,
})
);
} else {
// executable-prompt action
states.push(
await runPromptAction({
manager,
actionRef,
sample,
context,
promptConfig: actionConfig,
})
);
}
}

@@ -311,14 +326,16 @@ async function runFlowAction(params: {
});
state = {
...sample,
traceId: runActionResponse.telemetry?.traceId,
traceIds: runActionResponse.telemetry?.traceId
? [runActionResponse.telemetry?.traceId]
: [],
response: runActionResponse.result,
};
} catch (e: any) {
const traceId = e?.data?.details?.traceId;
state = {
...sample,
traceId,
traceIds: traceId ? [traceId] : [],
evalError: `Error when running inference. Details: ${e?.message ?? e}`,
};
}
@@ -341,14 +358,98 @@ async function runModelAction(params: {
});
state = {
...sample,
traceId: runActionResponse.telemetry?.traceId,
traceIds: runActionResponse.telemetry?.traceId
? [runActionResponse.telemetry?.traceId]
: [],
response: runActionResponse.result,
};
} catch (e: any) {
const traceId = e?.data?.details?.traceId;
state = {
...sample,
traceIds: traceId ? [traceId] : [],
evalError: `Error when running inference. Details: ${e?.message ?? e}`,
};
}
return state;
}

async function runPromptAction(params: {
manager: RuntimeManager;
actionRef: string;
sample: FullInferenceSample;
context?: any;
promptConfig?: any;
}): Promise<InferenceRunState> {
const { manager, actionRef, sample, context, promptConfig } = { ...params };

const { model: modelFromConfig, ...restOfConfig } = promptConfig ?? {};
const model = await resolveModel({ manager, actionRef, modelFromConfig });
if (!model) {
throw new Error(
'Could not resolve model. Please specify model and try again'
);
}
let state: InferenceRunState;
let renderedPrompt: {
result: GenerateActionOptions;
traceId: string;
};
// Step 1. Render the prompt with inputs
try {
const runActionResponse = await manager.runAction({
key: actionRef,
input: sample.input,
context: context ? JSON.parse(context) : undefined,
});

renderedPrompt = {
traceId: runActionResponse.telemetry?.traceId!,
result: GenerateActionOptionsSchema.parse(runActionResponse.result),
};
} catch (e: any) {
if (e instanceof z.ZodError) {
state = {
...sample,
traceIds: [],
evalError: `Error parsing prompt response. Details: ${JSON.stringify(e.format())}`,
};
} else {
const traceId = e?.data?.details?.traceId;
state = {
...sample,
traceIds: traceId ? [traceId] : [],
evalError: `Error when rendering prompt. Details: ${e?.message ?? e}`,
};
}
return state;
}
// Step 2. Run rendered prompt on the model
try {
let modelInput = renderedPrompt.result;
if (restOfConfig) {
modelInput = { ...modelInput, config: restOfConfig };
}
const runActionResponse = await manager.runAction({
key: model,
input: modelInput,
});
const traceIds = runActionResponse.telemetry?.traceId
? [renderedPrompt.traceId, runActionResponse.telemetry?.traceId]
: [renderedPrompt.traceId];
state = {
...sample,
traceIds: traceIds,
response: runActionResponse.result,
};
} catch (e: any) {
const traceId = e?.data?.details?.traceId;
const traceIds = traceId
? [renderedPrompt.traceId, traceId]
: [renderedPrompt.traceId];
state = {
...sample,
traceId,
traceIds: traceIds.filter((t): t is string => !!t),
evalError: `Error when running inference. Details: ${e?.message ?? e}`,
};
}
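For context on the two-step flow above: an executable-prompt run first invokes the prompt action to render a GenerateActionOptions payload (producing a render trace), then feeds that payload to the resolved model (producing a second trace), which is why InferenceRunState now carries traceIds as an array. A minimal sketch of the data at each step, using the types already imported in this file; the field values are illustrative, not taken from a real run:

// Step 1 result: the rendered prompt (shape per GenerateActionOptionsSchema).
const rendered: GenerateActionOptions = {
  model: 'googleai/gemini-2.5-flash',
  messages: [
    { role: 'system', content: [{ text: 'Only write code, do not explain' }] },
    { role: 'user', content: [{ text: 'Assist the user with: sort a list' }] },
  ],
};

// Step 2 input: eval-time config (everything except `model`) is overlaid
// onto the rendered payload before calling the resolved model action.
const modelInput = { ...rendered, config: { temperature: 0.75 } };

// Trace bookkeeping: on success, traceIds = [renderTraceId, modelTraceId];
// on a model-side failure, renderTraceId is still kept in the first slot.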
@@ -362,25 +463,38 @@ async function gatherEvalInput(params: {
}): Promise<EvalInput> {
const { manager, actionRef, state } = params;

const actionType = getSupportedActionType(actionRef);
const extractors = await getEvalExtractors(actionRef);
const traceId = state.traceId;
if (!traceId) {
logger.warn('No traceId available...');
const traceIds = state.traceIds;

if (
traceIds.length === 0 ||
(actionType === 'executable-prompt' && traceIds.length < 2)
) {
logger.warn('No valid traceId available...');
return {
...state,
error: state.evalError,
testCaseId: state.testCaseId,
traceIds: [],
traceIds: traceIds,
};
}

// Only the last collected trace to be used for evaluation.
const traceId = traceIds.at(-1)!;
const trace = await manager.getTrace({
traceId,
});

const isModelAction = actionRef.startsWith('/model');
// Always use original input for models.
const input = isModelAction ? state.input : extractors.input(trace);
// Always use original input for models and prompts.
const input = actionType === 'flow' ? extractors.input(trace) : state.input;
let custom = undefined;
if (actionType === 'executable-prompt') {
const promptTrace = await manager.getTrace({
traceId: traceIds[0],
});
custom = { renderedPrompt: extractors.output(promptTrace) };
}

const nestedSpan = stackTraceSpans(trace);
if (!nestedSpan) {
@@ -389,7 +503,8 @@
input,
error: `Unable to extract any spans from trace ${traceId}`,
reference: state.reference,
traceIds: [traceId],
custom,
traceIds: traceIds,
};
}

@@ -400,13 +515,15 @@
error:
getSpanErrorMessage(nestedSpan) ?? `Unknown error in trace ${traceId}`,
reference: state.reference,
traceIds: [traceId],
custom,
traceIds: traceIds,
};
}

const output = extractors.output(trace);
const context = extractors.context(trace);
const error = isModelAction ? getErrorFromModelResponse(output) : undefined;
const error =
actionType === 'model' ? getErrorFromModelResponse(output) : undefined;

return {
// TODO Replace this with unified trace class
@@ -416,10 +533,28 @@
error,
context: Array.isArray(context) ? context : [context],
reference: state.reference,
traceIds: [traceId],
custom,
traceIds: traceIds,
};
}

async function resolveModel(params: {
manager: RuntimeManager;
actionRef: string;
modelFromConfig?: string;
}) {
const { manager, actionRef, modelFromConfig } = { ...params };

// Prefer to use modelFromConfig
if (modelFromConfig) {
return modelFromConfig;
}

const actionData = await getAction({ manager, actionRef });
const promptMetadata = actionData?.metadata?.prompt as any;
return promptMetadata?.model ? `/model/${promptMetadata?.model}` : undefined;
}
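Resolution order in resolveModel: an explicit model from the eval config wins; otherwise the model name is read from the prompt action's metadata.prompt.model and prefixed into an action key. A hypothetical call, assuming a prompt whose frontmatter declares googleai/gemini-2.5-flash:

// No explicit model in the eval config, so fall back to prompt metadata:
const model = await resolveModel({
  manager,
  actionRef: '/executable-prompt/hello',
});
// → '/model/googleai/gemini-2.5-flash'
// An explicit modelFromConfig would have been returned verbatim instead.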

function getSpanErrorMessage(span: SpanData): string | undefined {
if (span && span.status?.code === 2 /* SpanStatusCode.ERROR */) {
// It's possible for a trace to have multiple exception events,
@@ -466,3 +601,16 @@ function isSupportedActionRef(actionRef: string) {
actionRef.startsWith(`/${supportedType}`)
);
}

function getSupportedActionType(actionRef: string): SupportedActionType {
if (actionRef.startsWith('/model')) {
return 'model';
}
if (actionRef.startsWith('/flow')) {
return 'flow';
}
if (actionRef.startsWith('/executable-prompt')) {
return 'executable-prompt';
}
throw new Error(`Unsupported action type: ${actionRef}`);
}
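A quick sanity check of the ref-to-type mapping, assuming action refs follow the /type/name convention used throughout this file:

getSupportedActionType('/model/googleai/gemini-2.5-flash'); // 'model'
getSupportedActionType('/flow/myFlow'); // 'flow'
getSupportedActionType('/executable-prompt/hello'); // 'executable-prompt'
getSupportedActionType('/retriever/myDocs'); // throws: Unsupported action type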
8 changes: 7 additions & 1 deletion genkit-tools/common/src/types/eval.ts
@@ -134,6 +134,7 @@ export const EvalInputSchema = z.object({
error: z.string().optional(),
context: z.array(z.any()).optional(),
reference: z.any().optional(),
custom: z.record(z.string(), z.any()).optional(),
[Collaborator commented here: need to run py/bin/generate_schema_typing?]
traceIds: z.array(z.string()),
});
export type EvalInput = z.infer<typeof EvalInputSchema>;
@@ -251,7 +252,12 @@ export const DatasetSchemaSchema = z.object({
});

/** Type of dataset, useful for UI niceties. */
export const DatasetTypeSchema = z.enum(['UNKNOWN', 'FLOW', 'MODEL']);
export const DatasetTypeSchema = z.enum([
'UNKNOWN',
'FLOW',
'MODEL',
'EXECUTABLE_PROMPT',
]);
export type DatasetType = z.infer<typeof DatasetTypeSchema>;

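Taken together, the two schema changes above mean an executable-prompt eval case can carry the rendered prompt alongside both trace ids. An illustrative EvalInput (values assumed, not from a real run):

const example: EvalInput = {
  testCaseId: 'case-1',
  input: { query: 'sort a list' },
  output: 'const sorted = [...xs].sort((a, b) => a - b);',
  custom: { renderedPrompt: '<output extracted from the render trace>' },
  traceIds: ['render-trace-id', 'model-trace-id'],
};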
12 changes: 11 additions & 1 deletion genkit-tools/common/src/utils/eval.ts
@@ -332,6 +332,16 @@ export async function hasAction(params: {
return actionsRecord.hasOwnProperty(actionRef);
}

export async function getAction(params: {
manager: RuntimeManager;
actionRef: string;
}): Promise<Action | undefined> {
const { manager, actionRef } = { ...params };
const allActions = await manager.listActions();

return Object.values(allActions).find((action) => action.key === actionRef);
}

/** Helper function that maps string data to GenerateRequest */
export function getModelInput(data: any, modelConfig: any): GenerateRequest {
let message: MessageData;
@@ -351,7 +361,7 @@ export function getModelInput(data: any, modelConfig: any): GenerateRequest {
} else {
const maybeRequest = GenerateRequestSchema.safeParse(data);
if (maybeRequest.success) {
return maybeRequest.data;
return { ...maybeRequest.data, config: modelConfig };
} else {
throw new Error(
`Unable to parse model input as MessageSchema. Details: ${maybeRequest.error}`
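The one-line getModelInput change matters for dataset rows that already parse as a full GenerateRequest: previously the eval's model config was dropped on that path, now it is overlaid. A sketch with an assumed row shape:

const row = { messages: [{ role: 'user', content: [{ text: 'hi' }] }] };
const request = getModelInput(row, { temperature: 0 });
// Before: { messages: [...] }, temperature silently ignored
// After:  { messages: [...], config: { temperature: 0 } }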
14 changes: 14 additions & 0 deletions js/testapps/evals/prompts/hello.prompt
@@ -0,0 +1,14 @@
---
model: googleai/gemini-2.5-flash
config:
  temperature: 0.75
input:
  schema:
    query: string
---

{{role "system"}}
Only write code, do not explain

{{role "user"}}
Assist the user with: {{query}}
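To exercise this prompt from the eval runner, dataset samples only need to match the prompt's input schema (query: string). A hypothetical pair of inference samples:

const inferenceDataset = [
  { testCaseId: 'case-1', input: { query: 'reverse a string in Python' } },
  { testCaseId: 'case-2', input: { query: 'binary search in TypeScript' } },
];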