diff --git a/firebase_test.go b/firebase_test.go index 1b367f1c..f29bb73b 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "firebase.google.com/go/messaging" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/option" @@ -361,6 +362,43 @@ func TestMessaging(t *testing.T) { } } +func TestMessagingSendWithCustomEndpoint(t *testing.T) { + name := "custom-endpoint-ok" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{ \"name\":\"" + name + "\" }")) + })) + defer ts.Close() + + ctx := context.Background() + + tokenSource := &testTokenSource{AccessToken: "mock-token-from-custom"} + app, err := NewApp( + ctx, + nil, + option.WithCredentialsFile("testdata/service_account.json"), + option.WithTokenSource(tokenSource), + option.WithEndpoint(ts.URL), + ) + if err != nil { + t.Fatal(err) + } + + c, err := app.Messaging(ctx) + if c == nil || err != nil { + t.Fatalf("Messaging() = (%v, %v); want (iid, nil)", c, err) + } + + msg := &messaging.Message{ + Token: "...", + } + n, err := c.Send(ctx, msg) + if n != name || err != nil { + t.Errorf("Send() = (%q, %v); want (%q, nil)", n, err, name) + } +} + func TestCustomTokenSource(t *testing.T) { ctx := context.Background() ts := &testTokenSource{AccessToken: "mock-token-from-custom"} diff --git a/messaging/messaging.go b/messaging/messaging.go index 37f61b00..aabfce3d 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -913,13 +913,17 @@ func NewClient(ctx context.Context, c *internal.MessagingConfig) (*Client, error return nil, errors.New("project ID is required to access Firebase Cloud Messaging client") } - hc, _, err := transport.NewHTTPClient(ctx, c.Opts...) + hc, endpoint, err := transport.NewHTTPClient(ctx, c.Opts...) if err != nil { return nil, err } + if endpoint == "" { + endpoint = messagingEndpoint + } + return &Client{ - fcmClient: newFCMClient(hc, c), + fcmClient: newFCMClient(hc, c, endpoint), iidClient: newIIDClient(hc), }, nil } @@ -932,7 +936,7 @@ type fcmClient struct { httpClient *internal.HTTPClient } -func newFCMClient(hc *http.Client, conf *internal.MessagingConfig) *fcmClient { +func newFCMClient(hc *http.Client, conf *internal.MessagingConfig, endpoint string) *fcmClient { client := internal.WithDefaultRetryConfig(hc) client.CreateErrFn = handleFCMError client.SuccessFn = internal.HasSuccessStatus @@ -944,7 +948,7 @@ func newFCMClient(hc *http.Client, conf *internal.MessagingConfig) *fcmClient { } return &fcmClient{ - fcmEndpoint: messagingEndpoint, + fcmEndpoint: endpoint, batchEndpoint: batchEndpoint, project: conf.ProjectID, version: version, diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 97f03205..be966513 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -1126,6 +1126,43 @@ func TestSend(t *testing.T) { } } +func TestSendWithCustomEndpoint(t *testing.T) { + var tr *http.Request + var b []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr = r + b, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) + })) + defer ts.Close() + + ctx := context.Background() + + conf := *testMessagingConfig + optEndpoint := option.WithEndpoint(ts.URL) + conf.Opts = append(conf.Opts, optEndpoint) + + client, err := NewClient(ctx, &conf) + if err != nil { + t.Fatal(err) + } + + if ts.URL != client.fcmEndpoint { + t.Errorf("client.fcmEndpoint = %q; want = %q", client.fcmEndpoint, ts.URL) + } + + for _, tc := range validMessages { + t.Run(tc.name, func(t *testing.T) { + name, err := client.Send(ctx, tc.req) + if name != testMessageID || err != nil { + t.Errorf("Send(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) + } + checkFCMRequest(t, b, tr, tc.want, false) + }) + } +} + func TestSendDryRun(t *testing.T) { var tr *http.Request var b []byte