diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 88259634..e4ada620 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -609,6 +609,47 @@ func (c *baseClient) GetUserByPhoneNumber(ctx context.Context, phone string) (*U }) } +// GetUserByProviderID gets the user data for the user corresponding to a given provider ID. +// +// See +// [Retrieve user data](https://firebase.google.com/docs/auth/admin/manage-users#retrieve_user_data) +// for code samples and detailed documentation. +// +// `providerID` indicates the provider, such as 'google.com' for the Google provider. +// `providerUID` is the user identifier for the given provider. +func (c *baseClient) GetUserByProviderID(ctx context.Context, providerID string, providerUID string) (*UserRecord, error) { + // Although we don't really advertise it, we want to also handle non-federated + // IDPs with this call. So if we detect one of them, we'll reroute this + // request appropriately. + if providerID == "phone" { + return c.GetUserByPhoneNumber(ctx, providerUID) + } else if providerID == "email" { + return c.GetUserByEmail(ctx, providerUID) + } + + if err := validateProvider(providerID, providerUID); err != nil { + return nil, err + } + + getUsersResult, err := c.GetUsers(ctx, []UserIdentifier{&ProviderIdentifier{providerID, providerUID}}) + if err != nil { + return nil, err + } + + if len(getUsersResult.Users) == 0 { + return nil, &internal.FirebaseError{ + ErrorCode: internal.NotFound, + String: fmt.Sprintf("cannot find user from providerID: { %s, %s }", providerID, providerUID), + Response: nil, + Ext: map[string]interface{}{ + authErrorCode: userNotFound, + }, + } + } + + return getUsersResult.Users[0], nil +} + type userQuery struct { field string value string diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 842363b0..32b2e726 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -141,22 +141,96 @@ func TestGetUserByPhoneNumber(t *testing.T) { } } +func TestGetUserByProviderIDNotFound(t *testing.T) { + mockUsers := []byte(`{ "users": [] }`) + s := echoServer(mockUsers, t) + defer s.Close() + + userRecord, err := s.Client.GetUserByProviderID(context.Background(), "google.com", "google_uid1") + want := "cannot find user from providerID: { google.com, google_uid1 }" + if userRecord != nil || err == nil || err.Error() != want || !IsUserNotFound(err) { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } +} + +func TestGetUserByProviderId(t *testing.T) { + cases := []struct { + providerID string + providerUID string + want string + }{ + { + "google.com", + "google_uid1", + `{"federatedUserId":[{"providerId":"google.com","rawId":"google_uid1"}]}`, + }, { + "phone", + "+15555550001", + `{"phoneNumber":["+15555550001"]}`, + }, { + "email", + "user@example.com", + `{"email":["user@example.com"]}`, + }, + } + + // The resulting user isn't parsed, so it just needs to exist (even if it's empty). + mockUsers := []byte(`{ "users": [{}] }`) + s := echoServer(mockUsers, t) + defer s.Close() + + for _, tc := range cases { + t.Run(tc.providerID+":"+tc.providerUID, func(t *testing.T) { + + _, err := s.Client.GetUserByProviderID(context.Background(), tc.providerID, tc.providerUID) + if err != nil { + t.Fatalf("GetUserByProviderID() = %q", err) + } + + got := string(s.Rbody) + if got != tc.want { + t.Errorf("GetUserByProviderID() Req = %v; want = %v", got, tc.want) + } + + wantPath := "/projects/mock-project-id/accounts:lookup" + if s.Req[0].RequestURI != wantPath { + t.Errorf("GetUserByProviderID() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) + } + }) + } +} + func TestInvalidGetUser(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } + user, err := client.GetUser(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUser('') = (%v, %v); want = (nil, error)", user, err) } + user, err = client.GetUserByEmail(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUserByEmail('') = (%v, %v); want = (nil, error)", user, err) } + user, err = client.GetUserByPhoneNumber(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUserPhoneNumber('') = (%v, %v); want = (nil, error)", user, err) } + + userRecord, err := client.GetUserByProviderID(context.Background(), "", "google_uid1") + want := "providerID must be a non-empty string" + if userRecord != nil || err == nil || err.Error() != want { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } + + userRecord, err = client.GetUserByProviderID(context.Background(), "google.com", "") + want = "providerUID must be a non-empty string" + if userRecord != nil || err == nil || err.Error() != want { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } } // Checks to see if the users list contain the given uids. Order is ignored. diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index e04fb680..13a27ddf 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -80,6 +80,34 @@ func TestGetUser(t *testing.T) { } } +func TestGetUserByProviderID(t *testing.T) { + // TODO(rsgowman): Once we can link a provider id with a user, just do that + // here instead of importing a new user. + importUserUID := randomUID() + providerUID := "google_" + importUserUID + userToImport := (&auth.UserToImport{}). + UID(importUserUID). + Email(randomEmail(importUserUID)). + PhoneNumber(randomPhoneNumber()). + ProviderData([](*auth.UserProvider){ + &auth.UserProvider{ + ProviderID: "google.com", + UID: providerUID, + }, + }) + importUser(t, importUserUID, userToImport) + defer deleteUser(importUserUID) + + userRecord, err := client.GetUserByProviderID(context.Background(), "google.com", providerUID) + if err != nil { + t.Fatalf("GetUserByProviderID() = %q", err) + } + + if userRecord.UID != importUserUID { + t.Errorf("GetUserByProviderID().UID = %v; want = %v", userRecord.UID, importUserUID) + } +} + func TestGetNonExistingUser(t *testing.T) { user, err := client.GetUser(context.Background(), "non.existing") if user != nil || !auth.IsUserNotFound(err) { @@ -90,6 +118,16 @@ func TestGetNonExistingUser(t *testing.T) { if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUserByEmail(non.existing) = (%v, %v); want = (nil, error)", user, err) } + + user, err = client.GetUserByPhoneNumber(context.Background(), "+14044040404") + if user != nil || !auth.IsUserNotFound(err) { + t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) + } + + user, err = client.GetUserByProviderID(context.Background(), "google.com", "a-uid-that-doesnt-exist") + if user != nil || !auth.IsUserNotFound(err) { + t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) + } } func TestGetUsers(t *testing.T) {