diff --git a/messaging/messaging.go b/messaging/messaging.go index 08335330..5fcfe168 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -32,8 +32,8 @@ import ( ) const ( - messagingEndpoint = "https://fcm.googleapis.com/v1" - batchEndpoint = "https://fcm.googleapis.com/batch" + defaultMessagingEndpoint = "https://fcm.googleapis.com/v1" + defaultBatchEndpoint = "https://fcm.googleapis.com/batch" firebaseClientHeader = "X-Firebase-Client" apiFormatVersionHeader = "X-GOOG-API-FORMAT-VERSION" @@ -862,17 +862,20 @@ func NewClient(ctx context.Context, c *internal.MessagingConfig) (*Client, error return nil, errors.New("project ID is required to access Firebase Cloud Messaging client") } - hc, endpoint, err := transport.NewHTTPClient(ctx, c.Opts...) + hc, messagingEndpoint, err := transport.NewHTTPClient(ctx, c.Opts...) if err != nil { return nil, err } - if endpoint == "" { - endpoint = messagingEndpoint + batchEndpoint := messagingEndpoint + + if messagingEndpoint == "" { + messagingEndpoint = defaultMessagingEndpoint + batchEndpoint = defaultBatchEndpoint } return &Client{ - fcmClient: newFCMClient(hc, c, endpoint), + fcmClient: newFCMClient(hc, c, messagingEndpoint, batchEndpoint), iidClient: newIIDClient(hc), }, nil } @@ -885,7 +888,7 @@ type fcmClient struct { httpClient *internal.HTTPClient } -func newFCMClient(hc *http.Client, conf *internal.MessagingConfig, endpoint string) *fcmClient { +func newFCMClient(hc *http.Client, conf *internal.MessagingConfig, messagingEndpoint string, batchEndpoint string) *fcmClient { client := internal.WithDefaultRetryConfig(hc) client.CreateErrFn = handleFCMError @@ -896,7 +899,7 @@ func newFCMClient(hc *http.Client, conf *internal.MessagingConfig, endpoint stri } return &fcmClient{ - fcmEndpoint: endpoint, + fcmEndpoint: messagingEndpoint, batchEndpoint: batchEndpoint, project: conf.ProjectID, version: version, diff --git a/messaging/messaging_batch_test.go b/messaging/messaging_batch_test.go index 8804342e..219ce2ce 100644 --- a/messaging/messaging_batch_test.go +++ b/messaging/messaging_batch_test.go @@ -27,6 +27,8 @@ import ( "net/http/httptest" "net/textproto" "testing" + + "google.golang.org/api/option" ) var testMessages = []*Message{ @@ -520,6 +522,46 @@ func TestSendMulticast(t *testing.T) { } } +func TestSendMulticastWithCustomEndpoint(t *testing.T) { + resp, err := createMultipartResponse(testSuccessResponse, nil) + if err != nil { + t.Fatal(err) + } + + var req []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req, _ = ioutil.ReadAll(r.Body) + w.Header().Set("Content-Type", wantMime) + w.Write(resp) + })) + defer ts.Close() + + ctx := context.Background() + + conf := *testMessagingConfig + customBatchEndpoint := fmt.Sprintf("%s/v1", ts.URL) + optEndpoint := option.WithEndpoint(customBatchEndpoint) + conf.Opts = append(conf.Opts, optEndpoint) + + client, err := NewClient(ctx, &conf) + if err != nil { + t.Fatal(err) + } + + if customBatchEndpoint != client.batchEndpoint { + t.Errorf("client.batchEndpoint = %q; want = %q", client.batchEndpoint, customBatchEndpoint) + } + + br, err := client.SendMulticast(ctx, testMulticastMessage) + if err != nil { + t.Fatal(err) + } + + if err := checkSuccessfulBatchResponse(br, req, false); err != nil { + t.Errorf("SendMulticast() = %v", err) + } +} + func TestSendMulticastDryRun(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil {