Skip to content

Commit 0bde5eb

Browse files
hugoaguirre authored and
hendrixmar committed
fix(go/plugins/googlegenai): ignore response mime type if tools are present (#2794)
1 parent 518e924 commit 0bde5eb

File tree

3 files changed

+130
-9
lines changed

3 files changed

+130
-9
lines changed

go/plugins/googlegenai/gemini.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func generate(
340340
return nil, err
341341
}
342342

343-
gcc, err := convertRequest(input, cache)
343+
gcc, err := toGeminiRequest(input, cache)
344344
if err != nil {
345345
return nil, err
346346
}
@@ -448,9 +448,9 @@ func generate(
448448
return r, nil
449449
}
450450

451-
// convertRequest translates from [*ai.ModelRequest] to
451+
// toGeminiRequest translates from [*ai.ModelRequest] to
452452
// *genai.GenerateContentConfig
453-
func convertRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) {
453+
func toGeminiRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) {
454454
gcc := genai.GenerateContentConfig{
455455
CandidateCount: 1,
456456
}
@@ -483,8 +483,9 @@ func convertRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.
483483
hasOutput := input.Output != nil
484484
isJsonFormat := hasOutput && input.Output.Format == "json"
485485
isJsonContentType := hasOutput && input.Output.ContentType == "application/json"
486-
jsonMode := isJsonFormat || (isJsonContentType && len(input.Tools) == 0)
487-
if jsonMode {
486+
jsonMode := isJsonFormat || isJsonContentType
487+
// this setting is not compatible with tools forcing controlled output generation
488+
if jsonMode && len(input.Tools) == 0 {
488489
gcc.ResponseMIMEType = "application/json"
489490
}
490491

@@ -507,7 +508,7 @@ func convertRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.
507508
gcc.Tools = tools
508509

509510
// Then set up the tool configuration based on ToolChoice
510-
tc, err := convertToolChoice(input.ToolChoice, input.Tools)
511+
tc, err := toGeminiToolChoice(input.ToolChoice, input.Tools)
511512
if err != nil {
512513
return nil, err
513514
}
@@ -708,7 +709,7 @@ func castToStringArray(i []any) []string {
708709
return r
709710
}
710711

711-
func convertToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) {
712+
func toGeminiToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) {
712713
var mode genai.FunctionCallingConfigMode
713714
switch toolChoice {
714715
case "":

go/plugins/googlegenai/gemini_test.go

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ import (
2222
"github.com/firebase/genkit/go/ai"
2323
)
2424

25-
func TestGeminiTools(t *testing.T) {
25+
func TestConvertRequest(t *testing.T) {
26+
text := "hello"
2627
tool := &ai.ToolDefinition{
2728
Description: "this is a dummy tool",
2829
Name: "myTool",
@@ -35,6 +36,92 @@ func TestGeminiTools(t *testing.T) {
3536
OutputSchema: map[string]any{"type": string("string")},
3637
}
3738

39+
req := &ai.ModelRequest{
40+
Config: GeminiConfig{
41+
MaxOutputTokens: 10,
42+
StopSequences: []string{"stop"},
43+
Temperature: 0.4,
44+
TopK: 1.0,
45+
TopP: 1.0,
46+
Version: text,
47+
},
48+
Tools: []*ai.ToolDefinition{tool},
49+
ToolChoice: ai.ToolChoiceAuto,
50+
Output: &ai.ModelOutputConfig{
51+
Constrained: true,
52+
Schema: map[string]any{"type": string("string")},
53+
},
54+
Messages: []*ai.Message{
55+
{
56+
Role: ai.RoleUser,
57+
Content: []*ai.Part{
58+
{Text: text},
59+
},
60+
},
61+
{
62+
Role: ai.RoleSystem,
63+
Content: []*ai.Part{
64+
{Text: text},
65+
},
66+
},
67+
{
68+
Role: ai.RoleUser,
69+
Content: []*ai.Part{
70+
{Text: text},
71+
},
72+
},
73+
{
74+
Role: ai.RoleSystem,
75+
Content: []*ai.Part{
76+
{Text: text},
77+
},
78+
},
79+
},
80+
}
81+
t.Run("convert request", func(t *testing.T) {
82+
gcc, err := toGeminiRequest(req, nil)
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
if gcc.SystemInstruction == nil {
87+
t.Error("expecting system instructions to be populated")
88+
}
89+
if len(gcc.SystemInstruction.Parts) != 2 {
90+
t.Errorf("got: %d, want: 2", len(gcc.SystemInstruction.Parts))
91+
}
92+
if gcc.SystemInstruction.Role != string(ai.RoleSystem) {
93+
t.Errorf(" system instruction role: got: %q, want: %q", gcc.SystemInstruction.Role, string(ai.RoleSystem))
94+
}
95+
// this is explicitly set to 1 in source
96+
if gcc.CandidateCount == nil {
97+
t.Error("candidate count: got: nil, want: 1")
98+
}
99+
ogCfg, ok := req.Config.(GeminiConfig)
100+
if !ok {
101+
t.Fatalf("request config should have been of type: GeminiConfig, got: %T", req.Config)
102+
}
103+
if gcc.MaxOutputTokens == nil {
104+
t.Errorf("max output tokens: got: nil, want %d", ogCfg.MaxOutputTokens)
105+
}
106+
if len(gcc.StopSequences) == 0 {
107+
t.Errorf("stop sequences: got: 0, want: %d", len(ogCfg.StopSequences))
108+
}
109+
if gcc.Temperature == nil {
110+
t.Errorf("temperature: got: nil, want %f", ogCfg.Temperature)
111+
}
112+
if gcc.TopP == nil {
113+
t.Errorf("topP: got: nil, want %f", ogCfg.TopP)
114+
}
115+
if gcc.TopK == nil {
116+
t.Errorf("topK: got: nil, want %d", ogCfg.TopK)
117+
}
118+
if gcc.ResponseMIMEType != "" {
119+
t.Errorf("ResponseMIMEType should been empty if tools are present")
120+
}
121+
if gcc.ResponseSchema == nil {
122+
t.Errorf("ResponseSchema should not be empty")
123+
}
124+
})
38125
t.Run("convert tools with valid tool", func(t *testing.T) {
39126
tools := []*ai.ToolDefinition{tool}
40127
gt, err := toGeminiTools(tools)

go/plugins/googlegenai/googleai_live_test.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/base64"
2222
"flag"
23+
"fmt"
2324
"io"
2425
"log"
2526
"math"
@@ -67,7 +68,7 @@ func TestGoogleAILive(t *testing.T) {
6768
t.Fatal(err)
6869
}
6970

70-
gablorkenTool := genkit.DefineTool(g, "gablorken", "use when need to calculate a gablorken",
71+
gablorkenTool := genkit.DefineTool(g, "gablorken", "use this tool when the user asks to calculate a gablorken",
7172
func(ctx *ai.ToolContext, input struct {
7273
Value int
7374
Over float64
@@ -165,7 +166,39 @@ func TestGoogleAILive(t *testing.T) {
165166
t.Errorf("got %q, expecting it to contain %q", out, want)
166167
}
167168
})
169+
t.Run("tool with json output", func(t *testing.T) {
170+
type weatherQuery struct {
171+
Location string `json:"location"`
172+
}
173+
174+
type weather struct {
175+
Report string `json:"report"`
176+
}
177+
178+
weatherTool := genkit.DefineTool(g, "weatherTool",
179+
"Use this tool to get the weather report for a specific location",
180+
func(ctx *ai.ToolContext, input weatherQuery) (string, error) {
181+
report := fmt.Sprintf("The weather in %s is sunny and 70 degrees today.", input.Location)
182+
return report, nil
183+
},
184+
)
168185

186+
resp, err := genkit.Generate(ctx, g,
187+
ai.WithTools(weatherTool),
188+
ai.WithPrompt("what's the weather in San Francisco?"),
189+
ai.WithOutputType(weather{}),
190+
)
191+
if err != nil {
192+
t.Fatal(err)
193+
}
194+
var w weather
195+
if err = resp.Output(&w); err != nil {
196+
t.Fatal(err)
197+
}
198+
if w.Report == "" {
199+
t.Fatal("empty weather report, tool should have provided an output")
200+
}
201+
})
169202
t.Run("avoid tool", func(t *testing.T) {
170203
resp, err := genkit.Generate(ctx, g,
171204
ai.WithPrompt("what is a gablorken of 2 over 3.5?"),

0 commit comments

Comments
 (0)