Merged
54 changes: 52 additions & 2 deletions go/plugins/googlegenai/gemini.go
@@ -18,6 +18,7 @@ package googlegenai

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
@@ -63,7 +64,7 @@ var (

// Attribution header
xGoogApiClientHeader = http.CanonicalHeaderKey("x-goog-api-client")
GenkitClientHeader = http.Header{
genkitClientHeader = http.Header{
xGoogApiClientHeader: {fmt.Sprintf("genkit-go/%s", internal.Version)},
}
)
@@ -174,6 +175,15 @@ type SafetySetting struct {
Threshold HarmBlockThreshold `json:"threshold,omitempty"`
}

// Modality indicates the type of content the model is expected to return.
type Modality string

const (
// ImageMode indicates the model should return images.
ImageMode Modality = "IMAGE"
// TextMode indicates the model should return text.
TextMode Modality = "TEXT"
)

// GeminiConfig mirrors GenerateContentConfig without direct genai dependency
type GeminiConfig struct {
// MaxOutputTokens is the maximum number of tokens to generate.
@@ -192,6 +202,8 @@ type GeminiConfig struct {
SafetySettings []*SafetySetting `json:"safetySettings,omitempty"`
// CodeExecution is whether to allow executing of code generated by the model.
CodeExecution bool `json:"codeExecution,omitempty"`
// ResponseModalities lists the modalities (e.g. TEXT, IMAGE) the model should return in its response.
ResponseModalities []Modality `json:"responseModalities,omitempty"`
}

// configFromRequest converts any supported config type to [GeminiConfig].
@@ -333,6 +345,23 @@ func generate(
return nil, err
}

if len(config.ResponseModalities) > 0 {
err := validateResponseModalities(model, config.ResponseModalities)
if err != nil {
return nil, err
}
for _, m := range config.ResponseModalities {
gcc.ResponseModalities = append(gcc.ResponseModalities, string(m))
}

// Work around a client-side error: if the model supports the TEXT
// modality but TEXT was not requested, the client returns an error,
// so always include TEXT.
if !slices.Contains(gcc.ResponseModalities, string(genai.ModalityText)) {
gcc.ResponseModalities = append(gcc.ResponseModalities, string(genai.ModalityText))
}
}

var contents []*genai.Content
for _, m := range input.Messages {
// system parts are handled separately
@@ -523,6 +552,23 @@ func convertRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.
return &gcc, nil
}

// validateResponseModalities checks that the requested response modalities are supported by the given model.
func validateResponseModalities(model string, modalities []Modality) error {
for _, m := range modalities {
switch m {
case ImageMode:
if !slices.Contains(imageGenModels, model) {
return fmt.Errorf("IMAGE response modality is not supported for model %q", model)
}
case TextMode:
continue
default:
return fmt.Errorf("unknown response modality provided: %q", m)
}
}
return nil
}

// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool].
func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
var outTools []*genai.Tool
@@ -724,7 +770,11 @@ func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
}
if part.InlineData != nil {
partFound++
p = ai.NewMediaPart(part.InlineData.MIMEType, string(part.InlineData.Data))
p = ai.NewMediaPart(part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data))
}
if part.FileData != nil {
partFound++
p = ai.NewMediaPart(part.FileData.MIMEType, part.FileData.FileURI)
}
if part.FunctionCall != nil {
partFound++
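For context, a minimal sketch of a package-internal unit test for the new validateResponseModalities helper (hypothetical and not part of this change; it assumes a _test.go file in the googlegenai package and the model constants defined in models.go):

func TestValidateResponseModalities(t *testing.T) {
	// IMAGE is only accepted for models listed in imageGenModels.
	if err := validateResponseModalities(gemini20FlashExp, []Modality{ImageMode, TextMode}); err != nil {
		t.Errorf("expected IMAGE to be allowed for %s, got: %v", gemini20FlashExp, err)
	}
	if err := validateResponseModalities(gemini20Flash, []Modality{ImageMode}); err == nil {
		t.Error("expected IMAGE to be rejected for a model without native image generation")
	}
	// Unknown modality strings are rejected as well.
	if err := validateResponseModalities(gemini20Flash, []Modality{"AUDIO"}); err == nil {
		t.Error("expected an unknown modality to be rejected")
	}
}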
28 changes: 28 additions & 0 deletions go/plugins/googlegenai/googleai_live_test.go
@@ -288,6 +288,34 @@ func TestGoogleAILive(t *testing.T) {
t.Fatalf("image detection failed, want: Mario Kart, got: %s", resp.Text())
}
})
t.Run("image generation", func(t *testing.T) {
m := googlegenai.GoogleAIModel(g, "gemini-2.0-flash-exp")
resp, err := genkit.Generate(ctx, g,
ai.WithConfig(googlegenai.GeminiConfig{
ResponseModalities: []googlegenai.Modality{googlegenai.ImageMode, googlegenai.TextMode},
}),
ai.WithMessages(
ai.NewUserTextMessage("generate an image of a dog wearing a black tejana while playing the accordion"),
),
ai.WithModel(m),
)
if err != nil {
t.Fatal(err)
}
if len(resp.Message.Content) == 0 {
t.Fatal("empty response")
}
part := resp.Message.Content[0]
if part.ContentType != "image/png" {
t.Errorf("expecting image/png content type but got: %q", part.ContentType)
}
if part.Kind != ai.PartMedia {
t.Errorf("expecting part to be Media type but got: %q", part.Kind)
}
if part.Text == "" {
t.Errorf("empty response")
}
})
t.Run("constrained generation", func(t *testing.T) {
type outFormat struct {
Country string
4 changes: 2 additions & 2 deletions go/plugins/googlegenai/googlegenai.go
@@ -85,7 +85,7 @@ func (ga *GoogleAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
Backend: genai.BackendGeminiAPI,
APIKey: apiKey,
HTTPOptions: genai.HTTPOptions{
Headers: GenkitClientHeader,
Headers: genkitClientHeader,
},
}

@@ -159,7 +159,7 @@ func (v *VertexAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
Project: v.ProjectID,
Location: v.Location,
HTTPOptions: genai.HTTPOptions{
Headers: GenkitClientHeader,
Headers: genkitClientHeader,
},
}

13 changes: 13 additions & 0 deletions go/plugins/googlegenai/models.go
@@ -16,6 +16,7 @@ const (
gemini15Flash8b = "gemini-1.5-flash-8b"

gemini20Flash = "gemini-2.0-flash"
gemini20FlashExp = "gemini-2.0-flash-exp"
gemini20FlashLite = "gemini-2.0-flash-lite"
gemini20FlashLitePrev = "gemini-2.0-flash-lite-preview"
gemini20ProExp0205 = "gemini-2.0-pro-exp-02-05"
@@ -45,13 +46,19 @@ var (
gemini15Pro,
gemini15Flash8b,
gemini20Flash,
gemini20FlashExp,
gemini20FlashLitePrev,
gemini20ProExp0205,
gemini20FlashThinkingExp0121,
gemini25ProExp0325,
gemini25ProPreview0325,
}

// models with native image generation support
imageGenModels = []string{
gemini20FlashExp,
}

supportedGeminiModels = map[string]ai.ModelInfo{
gemini15Flash: {
Label: "Gemini 1.5 Flash",
@@ -90,6 +97,12 @@
Supports: &Multimodal,
Stage: ai.ModelStageStable,
},
gemini20FlashExp: {
Label: "Gemini 2.0 Flash Exp",
Versions: []string{},
Supports: &Multimodal,
Stage: ai.ModelStageUnstable,
},
gemini20FlashLite: {
Label: "Gemini 2.0 Flash Lite",
Versions: []string{
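For illustration, a hedged note on how the two lists above interact: supportedGeminiModels controls which models get registered, while imageGenModels controls which of them validateResponseModalities allows to return IMAGE output. The helper below combining both checks is hypothetical, not part of this change, and would need the slices import:

func supportsImageGeneration(model string) bool {
	// A model must be registered and listed as an image-generation model.
	_, known := supportedGeminiModels[model]
	return known && slices.Contains(imageGenModels, model)
}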
107 changes: 107 additions & 0 deletions go/samples/imagen-gemini/main.go
@@ -0,0 +1,107 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"os"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/googlegenai"
)

func main() {
ctx := context.Background()

// Initialize Genkit with the Google AI plugin. If APIKey is left unset,
// the plugin reads the API key from the GEMINI_API_KEY or GOOGLE_API_KEY
// environment variable, which is the recommended practice.
g, err := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
if err != nil {
log.Fatal(err)
}

// Define a simple flow that generates a story about a given topic and an image for each scene
genkit.DefineFlow(g, "imageFlow", func(ctx context.Context, input string) (string, error) {
m := googlegenai.GoogleAIModel(g, "gemini-2.0-flash-exp")
if m == nil {
return "", errors.New("imageFlow: failed to find model")
}

if input == "" {
input = `A little blue gopher with big eyes trying to learn Python,
use a cartoon style, the story should be tragic because he
chose the wrong programming language, the proper programming
language for a gopher should be Go`
}
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&googlegenai.GeminiConfig{
Temperature: 0.5,
ResponseModalities: []googlegenai.Modality{
googlegenai.ImageMode,
googlegenai.TextMode,
},
}),
ai.WithPrompt(fmt.Sprintf(`generate a story about %s and for each scene, generate an image for it`, input)))
if err != nil {
return "", err
}

story := ""
scene := 0
for _, p := range resp.Message.Content {
if p.IsMedia() {
scene++
// Check the write error per part so a failure is not overwritten
// by a later successful write.
if err := base64toFile(p.Text, fmt.Sprintf("scene_%d.png", scene)); err != nil {
return "", err
}
}
if p.IsText() {
story += p.Text
}
}

return story, nil
})

<-ctx.Done()
}

func base64toFile(data, path string) error {
dec, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()

_, err = f.Write(dec)
if err != nil {
return err
}

return f.Sync()
}
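One possible refinement of this sample (a suggestion, not part of the change): derive the output file extension from the media part's ContentType instead of hard-coding .png. A minimal sketch:

// extensionFor maps the MIME types the sample expects to file extensions,
// falling back to .bin for anything else. Illustrative only.
func extensionFor(contentType string) string {
	switch contentType {
	case "image/png":
		return ".png"
	case "image/jpeg":
		return ".jpg"
	default:
		return ".bin"
	}
}

The loop in imageFlow could then name each file fmt.Sprintf("scene_%d%s", scene, extensionFor(p.ContentType)).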