Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const (
type GenerateActionOptions struct {
Config any `json:"config,omitempty"`
Docs []*Document `json:"docs,omitempty"`
LongRunning bool `json:"longRunning,omitempty"`
MaxTurns int `json:"maxTurns,omitempty"`
Messages []*Message `json:"messages,omitempty"`
Model string `json:"model,omitempty"`
Expand Down Expand Up @@ -259,6 +260,8 @@ type ModelResponse struct {
// LatencyMs is the time the request took in milliseconds.
LatencyMs float64 `json:"latencyMs,omitempty"`
Message *Message `json:"message,omitempty"`
// Operation holds the background operation details for long-running operations.
Operation map[string]any `json:"operation,omitempty"`
// Request is the [ModelRequest] struct used to trigger this response.
Request *ModelRequest `json:"request,omitempty"`
// Usage describes how many resources were used by this generation request.
Expand Down
147 changes: 101 additions & 46 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
}
model, _ := m.(*model)

// Validate long-running support if requested
if opts.LongRunning && !SupportsLongRunning(r, opts.Model) {
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support long-running operations", opts.Model)
}

resumeOutput, err := handleResumeOption(ctx, r, opts)
if err != nil {
return nil, err
Expand Down Expand Up @@ -338,7 +343,49 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
Output: &outputCfg,
}

fn := core.ChainMiddleware(mw...)(model.Generate)
var fn ModelFunc
if opts.LongRunning {
provider, name, _ := strings.Cut(opts.Model, "/")
bgAction := LookupBackgroundModel(r, provider, name)
if bgAction == nil {
return nil, core.NewError(core.NOT_FOUND, "background model %q not found", opts.Model)
}

// Create a wrapper function that calls the background model but returns a ModelResponse with operation
fn = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
op, err := bgAction.Start(ctx, req)
if err != nil {
return nil, err
}

// Return response with operation
operationMap := map[string]any{
"action": op.Action,
"id": op.ID,
"done": op.Done,
}
if op.Output != nil {
operationMap["output"] = op.Output
}
if op.Error != nil {
operationMap["error"] = map[string]any{
"message": op.Error.Error(),
}
}
if op.Metadata != nil {
operationMap["metadata"] = op.Metadata
}

return &ModelResponse{
Operation: operationMap,
Request: req,
}, nil
}
} else {
fn = model.Generate
}

fn = core.ChainMiddleware(mw...)(fn)

currentTurn := 0
for {
Expand All @@ -347,6 +394,11 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
return nil, err
}

// If this is a long-running operation response, return it immediately without further processing
if resp.Operation != nil {
return resp, nil
}

if formatHandler != nil {
resp.Message, err = formatHandler.ParseMessage(resp.Message)
if err != nil {
Expand Down Expand Up @@ -491,6 +543,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
Config: genOpts.Config,
ToolChoice: genOpts.ToolChoice,
Docs: genOpts.Documents,
LongRunning: genOpts.LongRunning,
ReturnToolRequests: genOpts.ReturnToolRequests != nil && *genOpts.ReturnToolRequests,
Output: &GenerateActionOutputConfig{
JsonSchema: genOpts.OutputSchema,
Expand Down Expand Up @@ -1172,66 +1225,68 @@ func SupportsLongRunning(r *registry.Registry, modelName string) bool {
return longRunning
}

// GenerateOperation generates a model response as a long-running operation based on the provided options.
// It returns an error if the model does not support long-running operations.
// GenerateOperation generates a model response as a long-running operation based on the provided options. s.
func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*core.Operation[*ModelResponse], error) {
genOpts := &generateOptions{}
for _, opt := range opts {
if err := opt.applyGenerate(genOpts); err != nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: error applying options: %v", err)
}

opts = append(opts, WithLongRunning())

resp, err := Generate(ctx, r, opts...)
if err != nil {
return nil, err
}

if !SupportsLongRunning(r, genOpts.ModelName) {
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support long-running operations", genOpts.ModelName)
if resp.Operation == nil {
return nil, core.NewError(core.FAILED_PRECONDITION, "model did not return an operation")
}

var modelName string
if genOpts.Model != nil {
modelName = genOpts.Model.Name()
var action string
if v, ok := resp.Operation["action"].(string); ok {
action = v
} else {
modelName = genOpts.ModelName
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'action' field")
}

provider, name, _ := strings.Cut(modelName, "/")
bgAction := LookupBackgroundModel(r, provider, name)
if bgAction == nil {
return nil, core.NewError(core.NOT_FOUND, "background model %q not found", modelName)
var id string
if v, ok := resp.Operation["id"].(string); ok {
id = v
} else {
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'id' field")
}

var messages []*Message
if genOpts.SystemFn != nil {
system, err := genOpts.SystemFn(ctx, nil)
if err != nil {
return nil, err
}

messages = append(messages, NewSystemTextMessage(system))
var done bool
if v, ok := resp.Operation["done"].(bool); ok {
done = v
} else {
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'done' field")
}
if genOpts.MessagesFn != nil {
msgs, err := genOpts.MessagesFn(ctx, nil)
if err != nil {
return nil, err
}

messages = append(messages, msgs...)
var metadata map[string]any
if v, ok := resp.Operation["metadata"].(map[string]any); ok {
metadata = v
}
if genOpts.PromptFn != nil {
prompt, err := genOpts.PromptFn(ctx, nil)
if err != nil {
return nil, err
}

messages = append(messages, NewUserTextMessage(prompt))
op := &core.Operation[*ModelResponse]{
Action: action,
ID: id,
Done: done,
Metadata: metadata,
}

if modelRef, ok := genOpts.Model.(ModelRef); ok && genOpts.Config == nil {
genOpts.Config = modelRef.Config()
if op.Done {
if output, ok := resp.Operation["output"]; ok {
if modelResp, ok := output.(*ModelResponse); ok {
op.Output = modelResp
} else {
op.Output = resp
}
} else {
op.Output = resp
}
}

op, err := bgAction.Start(ctx, &ModelRequest{Messages: messages, Config: genOpts.Config})
if err != nil {
return nil, err
if errorData, ok := resp.Operation["error"]; ok {
if errorMap, ok := errorData.(map[string]any); ok {
if message, ok := errorMap["message"].(string); ok {
op.Error = errors.New(message)
}
}
}

return op, nil
Expand Down
10 changes: 10 additions & 0 deletions go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ type generateOptions struct {
documentOptions
RespondParts []*Part // Tool responses to return from interrupted tool calls.
RestartParts []*Part // Tool requests to restart interrupted tools with.
LongRunning bool // Whether to use long-running operation mode.
}

// GenerateOption is an option for generating a model response. It applies only to Generate().
Expand Down Expand Up @@ -770,9 +771,18 @@ func (o *generateOptions) applyGenerate(genOpts *generateOptions) error {
genOpts.RestartParts = o.RestartParts
}

if o.LongRunning {
genOpts.LongRunning = true
}

return nil
}

// WithLongRunning sets the generation to use long-running operation mode.
func WithLongRunning() GenerateOption {
return &generateOptions{LongRunning: true}
}

// WithToolResponses sets the tool responses to return from interrupted tool calls.
func WithToolResponses(parts ...*Part) GenerateOption {
return &generateOptions{RespondParts: parts}
Expand Down
91 changes: 90 additions & 1 deletion go/samples/veo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/googlegenai"
"google.golang.org/genai"
Expand Down Expand Up @@ -87,9 +92,93 @@ func main() {
currentOp = updatedOp
}

// Operation completed, return the final result
if currentOp != nil {
opJson, _ := json.Marshal(currentOp.Output)
fmt.Printf("%s", opJson)

// Download the generated video
if err := downloadGeneratedVideo(ctx, currentOp); err != nil {
log.Printf("Failed to download video: %v", err)
} else {
fmt.Println("Video successfully downloaded to veo3_video.mp4")
}
}
}

// For testing purpose need to be removed if needed
// downloadGeneratedVideo downloads the generated video from the operation result
func downloadGeneratedVideo(ctx context.Context, operation *core.Operation[*ai.ModelResponse]) error {
// Get the API key from environment
apiKey := os.Getenv("GEMINI_API_KEY")
if apiKey == "" {
apiKey = os.Getenv("GOOGLE_API_KEY")
}
if apiKey == "" {
return fmt.Errorf("no API key found. Please set GEMINI_API_KEY or GOOGLE_API_KEY environment variable")
}

// Parse the operation output to extract video URL
if operation.Output == nil {
return fmt.Errorf("operation output is nil")
}

modelResponse := operation.Output
if modelResponse.Message == nil {
return fmt.Errorf("model response message is nil")
}

// Find the media part in the message content
var videoURL string
for _, part := range modelResponse.Message.Content {
if part.IsMedia() && part.Text != "" {
videoURL = part.Text
break
}
}

if videoURL == "" {
return fmt.Errorf("no video URL found in the operation output")
}

// Append API key to the URL if it's not already there
downloadURL := videoURL
if !strings.Contains(downloadURL, "key=") {
separator := "&"
if !strings.Contains(downloadURL, "?") {
separator = "?"
}
downloadURL = fmt.Sprintf("%s%skey=%s", videoURL, separator, apiKey)
}

req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %v", err)
}

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download video: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download video: HTTP %d", resp.StatusCode)
}

// Create the output file
filename := "veo3_video.mp4"
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer file.Close()

// Copy the video content to the file
_, err = io.Copy(file, resp.Body)
if err != nil {
return fmt.Errorf("failed to write video data to file: %v", err)
}

return nil
}
Loading