Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)

Expand All @@ -32,6 +33,32 @@ type Embedder interface {
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}

// EmbedderInfo holds descriptive metadata about an embedder. Every field is
// optional and is omitted from the JSON encoding when it has its zero value.
type EmbedderInfo struct {
	// Label is a user-friendly display name for the embedder.
	Label string `json:"label,omitempty"`
	// Supports describes the capabilities of the embedder, such as the input
	// types it accepts and whether it handles multiple languages.
	Supports *EmbedderSupports `json:"supports,omitempty"`
	// Dimensions is the number of dimensions in the embedding vector.
	Dimensions int `json:"dimensions,omitempty"`
}

// EmbedderSupports describes the capabilities advertised by an embedder.
type EmbedderSupports struct {
	// Input lists the kinds of data the embedder can process
	// (e.g., "text", "image", "video").
	Input []string `json:"input,omitempty"`
	// Multilingual reports whether the embedder supports multiple languages.
	Multilingual bool `json:"multilingual,omitempty"`
}

// EmbedderOptions carries the optional settings supplied when an embedder is
// defined with DefineEmbedder.
type EmbedderOptions struct {
	// ConfigSchema is a value describing the embedder's configuration options.
	// DefineEmbedder infers a JSON schema from it; leave it nil when the
	// embedder has no custom options.
	ConfigSchema any `json:"configSchema,omitempty"`
	// Info contains metadata about the embedder, such as its label and capabilities.
	Info *EmbedderInfo `json:"info,omitempty"`
}

// An embedder is used to convert a document to a multidimensional vector.
// It is the action-backed value that DefineEmbedder returns as an [Embedder].
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

Expand All @@ -40,9 +67,22 @@ type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
func DefineEmbedder(
	r *registry.Registry,
	provider, name string,
	opts *EmbedderOptions,
	embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
	// opts is optional. Substitute an empty value for nil so the field reads
	// below cannot dereference a nil pointer.
	if opts == nil {
		opts = &EmbedderOptions{}
	}

	// Metadata recorded with the action: the action kind, the descriptive
	// info, and (only when a config schema is given) the custom options schema.
	metadata := map[string]any{
		"type": "embedder",
		"info": opts.Info,
	}
	if opts.ConfigSchema != nil {
		metadata["embedder"] = map[string]any{"customOptions": base.ToSchemaMap(opts.ConfigSchema)}
	}

	// Start from the schema inferred for EmbedRequest and, when a config
	// schema is given, replace the existing "options" property with the schema
	// inferred from it. The property is only overwritten, never introduced.
	inputSchema := base.InferJSONSchema(EmbedRequest{})
	if inputSchema.Properties != nil && opts.ConfigSchema != nil {
		if _, ok := inputSchema.Properties.Get("options"); ok {
			inputSchema.Properties.Set("options", base.InferJSONSchema(opts.ConfigSchema))
		}
	}

	return (*embedder)(core.DefineActionWithInputSchema(r, provider, name, core.ActionTypeEmbedder, metadata, inputSchema, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
Expand Down
32 changes: 30 additions & 2 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)

Expand All @@ -35,12 +36,39 @@ type Retriever interface {
Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error)
}

// RetrieverInfo holds descriptive metadata about a retriever. Both fields are
// optional and are omitted from the JSON encoding when unset.
type RetrieverInfo struct {
	// Label is a user-friendly display name for the retriever.
	Label string `json:"label,omitempty"`
	// Supports describes the capabilities of the retriever, such as media support.
	Supports *RetrieverSupports `json:"supports,omitempty"`
}

// RetrieverSupports describes the capabilities advertised by a retriever.
type RetrieverSupports struct {
	// Media reports whether the retriever supports media content.
	Media bool `json:"media,omitempty"`
}

// RetrieverOptions carries the optional settings supplied when a retriever is
// defined with DefineRetriever.
type RetrieverOptions struct {
	// ConfigSchema is a value describing the retriever's configuration options.
	// DefineRetriever infers a JSON schema from it; leave it nil when the
	// retriever has no custom options.
	ConfigSchema any
	// Info contains metadata about the retriever, such as its label and capabilities.
	Info *RetrieverInfo
}
// A retriever is the action-backed value that DefineRetriever returns as a
// [Retriever].
type retriever core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]

// DefineRetriever registers the given retrieve function as an action, and returns a
// [Retriever] that runs it.
func DefineRetriever(r *registry.Registry, provider, name string, fn RetrieverFunc) Retriever {
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, nil, fn))
func DefineRetriever(r *registry.Registry, provider, name string, opts *RetrieverOptions, fn RetrieverFunc) Retriever {
metadata := map[string]any{}
metadata["type"] = "retriever"
metadata["info"] = opts.Info
if opts.ConfigSchema != nil {
metadata["retriever"] = map[string]any{"customOptions": base.InferJSONSchema(opts.ConfigSchema)}
}
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, metadata, fn))
}

// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
Expand Down
14 changes: 8 additions & 6 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ type noStream = func(context.Context, struct{}) error
// DefineAction creates a new non-streaming Action and registers it.
func DefineAction[In, Out any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype ActionType,
metadata map[string]any,
fn Func[In, Out],
Expand Down Expand Up @@ -140,24 +141,25 @@ func DefineStreamingAction[In, Out, Stream any](
// This differs from DefineAction in that the input schema is
// defined dynamically; the static input type is "any".
// This is used for prompts and tools that need custom input validation.
func DefineActionWithInputSchema[Out any](
func DefineActionWithInputSchema[In, Out any](
r *registry.Registry,
provider, name string,
atype ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn Func[any, Out],
) *ActionDef[any, Out, struct{}] {
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
return defineAction(r, provider, name, atype, metadata, inputSchema,
func(ctx context.Context, in any, _ noStream) (Out, error) {
func(ctx context.Context, in In, _ noStream) (Out, error) {
return fn(ctx, in)
})
}

// defineAction creates an action and registers it with the given Registry.
func defineAction[In, Out, Stream any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
Expand Down
8 changes: 4 additions & 4 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,8 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp
// The `provider` and `name` form the unique identifier. The `ret` function
// contains the logic to process an [ai.RetrieverRequest] (containing the query)
// and return an [ai.RetrieverResponse] (containing the relevant documents).
func DefineRetriever(g *Genkit, provider, name string, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
return ai.DefineRetriever(g.reg, provider, name, ret)
func DefineRetriever(g *Genkit, provider, name string, opts *ai.RetrieverOptions, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
return ai.DefineRetriever(g.reg, provider, name, opts, ret)
}

// LookupRetriever retrieves a registered [ai.Retriever] by its provider and name.
Expand All @@ -764,8 +764,8 @@ func LookupRetriever(g *Genkit, provider, name string) ai.Retriever {
// The `provider` and `name` form the unique identifier. The `embed` function
// contains the logic to process an [ai.EmbedRequest] (containing documents or a query)
// and return an [ai.EmbedResponse] (containing the corresponding embeddings).
func DefineEmbedder(g *Genkit, provider, name string, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, embed)
func DefineEmbedder(g *Genkit, provider string, name string, opts *ai.EmbedderOptions, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, opts, embed)
}

// LookupEmbedder retrieves a registered [ai.Embedder] by its provider and name.
Expand Down
6 changes: 6 additions & 0 deletions go/internal/base/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,9 @@ func GetJsonObjectLines(text string) []string {
// Return the slice containing the filtered and trimmed lines.
return result
}

func ToSchemaMap(config any) map[string]any {
schema := InferJSONSchema(config)
result := SchemaAsMap(schema)
return result
}
11 changes: 10 additions & 1 deletion go/internal/doc-snippets/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,19 @@ func pineconeEx(ctx context.Context) error {
var docChunks []*ai.Document

// [START defineretriever]
retOpts := &ai.RetrieverOptions{
ConfigSchema: pinecone.PineconeRetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "Pinecone",
Supports: &ai.RetrieverSupports{
Media: false,
},
},
}
ds, menuRetriever, err := pinecone.DefineRetriever(ctx, g, pinecone.Config{
IndexID: "menu_data", // Your Pinecone index
Embedder: googlegenai.GoogleAIEmbedder(g, "text-embedding-004"), // Embedding model of your choice
})
}, retOpts)
if err != nil {
return err
}
Expand Down
42 changes: 42 additions & 0 deletions go/internal/doc-snippets/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,23 @@ func main() {
if err != nil {
log.Fatal(err)
}
retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.RetrieverSupports{
Media: false,
},
},
}

docStore, _, err := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -155,12 +165,23 @@ func menuQA() {

model := googlegenai.VertexAIModel(g, "gemini-1.5-flash")

retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.RetrieverSupports{
Media: false,
},
},
}

_, menuPdfRetriever, err := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -207,23 +228,44 @@ func customret() {
log.Fatal(err)
}

retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.RetrieverSupports{
Media: false,
},
},
}

_, menuPDFRetriever, _ := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)

// [START customret]
type CustomMenuRetrieverOptions struct {
K int
PreRerankK int
}
genRetOpts := &ai.RetrieverOptions{
ConfigSchema: CustomMenuRetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "advancedMenuRetriever",
Supports: &ai.RetrieverSupports{
Media: false,
},
},
}
advancedMenuRetriever := genkit.DefineRetriever(
g,
"custom",
"advancedMenuRetriever",
genRetOpts,
func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Handle options passed using our custom type.
opts, _ := req.Options.(CustomMenuRetrieverOptions)
Expand Down
12 changes: 11 additions & 1 deletion go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ func TestFakeEmbedder(t *testing.T) {
}

embed := New()
emb := ai.DefineEmbedder(r, "fake", "embed", embed.Embed)
emdOpts := &ai.EmbedderOptions{
Info: &ai.EmbedderInfo{
Dimensions: 32,
Label: "embed",
Supports: &ai.EmbedderSupports{
Input: []string{"text"},
},
},
ConfigSchema: nil,
}
emb := ai.DefineEmbedder(r, "fake", "embed", emdOpts, embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/alloydb/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func DefineRetriever(ctx context.Context, g *genkit.Genkit, p *Postgres, cfg *Co
return nil, nil, err
}

return ds, genkit.DefineRetriever(g, provider, ds.config.TableName, ds.Retrieve), nil
return ds, genkit.DefineRetriever(g, provider, ds.config.TableName, nil, ds.Retrieve), nil
}

// Retriever returns the retriever with the given index name.
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/alloydb/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestPostgres(t *testing.T) {
IDColumn: CustomIdColumn,
MetadataJSONColumn: CustomMetadataColumn,
IgnoreMetadataColumns: []string{"created_at", "updated_at"},
Embedder: genkit.DefineEmbedder(g, "fake", "embedder3", embedder.Embed),
Embedder: genkit.DefineEmbedder(g, "fake", "embedder3", nil, embedder.Embed),
EmbedderOptions: nil,
}

Expand Down
4 changes: 2 additions & 2 deletions go/plugins/compat_oai/compat_oai.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ func (o *OpenAICompatible) DefineModel(g *genkit.Genkit, provider, name string,
}

// DefineEmbedder defines an embedder with a given name.
func (o *OpenAICompatible) DefineEmbedder(g *genkit.Genkit, provider, name string) (ai.Embedder, error) {
func (o *OpenAICompatible) DefineEmbedder(g *genkit.Genkit, provider, name string, embedOpts *ai.EmbedderOptions) (ai.Embedder, error) {
o.mu.Lock()
defer o.mu.Unlock()
if !o.initted {
return nil, errors.New("OpenAICompatible.Init not called")
}

return genkit.DefineEmbedder(g, provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
return genkit.DefineEmbedder(g, provider, name, embedOpts, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
var data openaiGo.EmbeddingNewParamsInputArrayOfStrings
for _, doc := range input.Input {
for _, p := range doc.Content {
Expand Down
Loading
Loading