Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 108 additions & 152 deletions auth/user_mgt.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,7 @@ type getAccountInfoResponse struct {

func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) {
var parsed getAccountInfoResponse
_, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed)
if err != nil {
if _, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed); err != nil {
return nil, err
}

Expand All @@ -584,10 +583,8 @@ func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord

// A UserIdentifier identifies a user to be looked up.
type UserIdentifier interface {
String() string
toString() string

validate() error
matchesUserRecord(ur *UserRecord) (bool, error)
populateRequest(req *getAccountInfoRequest) error
}

Expand All @@ -598,31 +595,12 @@ type UIDIdentifier struct {
UID string
}

func (id UIDIdentifier) String() string {
func (id UIDIdentifier) toString() string {
return fmt.Sprintf("UIDIdentifier{%s}", id.UID)
}

func (id UIDIdentifier) validate() error {
return validateUID(id.UID)
}

func (id UIDIdentifier) matchesUserRecord(ur *UserRecord) (bool, error) {
err := id.validate()
if err != nil {
return false, err
}
if id.UID == ur.UID {
return true, nil
}
return false, nil
}

func (id UIDIdentifier) populateRequest(req *getAccountInfoRequest) error {
err := id.validate()
if err != nil {
return err
}
req.localID = append(req.localID, id.UID)
req.LocalID = append(req.LocalID, id.UID)
return nil
}

Expand All @@ -633,31 +611,12 @@ type EmailIdentifier struct {
Email string
}

func (id EmailIdentifier) String() string {
func (id EmailIdentifier) toString() string {
return fmt.Sprintf("EmailIdentifier{%s}", id.Email)
}

func (id EmailIdentifier) validate() error {
return validateEmail(id.Email)
}

func (id EmailIdentifier) matchesUserRecord(ur *UserRecord) (bool, error) {
err := id.validate()
if err != nil {
return false, err
}
if id.Email == ur.Email {
return true, nil
}
return false, nil
}

func (id EmailIdentifier) populateRequest(req *getAccountInfoRequest) error {
err := id.validate()
if err != nil {
return err
}
req.email = append(req.email, id.Email)
req.Email = append(req.Email, id.Email)
return nil
}

Expand All @@ -668,31 +627,12 @@ type PhoneIdentifier struct {
PhoneNumber string
}

func (id PhoneIdentifier) String() string {
func (id PhoneIdentifier) toString() string {
return fmt.Sprintf("PhoneIdentifier{%s}", id.PhoneNumber)
}

func (id PhoneIdentifier) validate() error {
return validatePhone(id.PhoneNumber)
}

func (id PhoneIdentifier) matchesUserRecord(ur *UserRecord) (bool, error) {
err := id.validate()
if err != nil {
return false, err
}
if id.PhoneNumber == ur.PhoneNumber {
return true, nil
}
return false, nil
}

func (id PhoneIdentifier) populateRequest(req *getAccountInfoRequest) error {
err := id.validate()
if err != nil {
return err
}
req.phoneNumber = append(req.phoneNumber, id.PhoneNumber)
req.PhoneNumber = append(req.PhoneNumber, id.PhoneNumber)
return nil
}

Expand All @@ -704,35 +644,14 @@ type ProviderIdentifier struct {
ProviderUID string
}

func (id ProviderIdentifier) String() string {
func (id ProviderIdentifier) toString() string {
return fmt.Sprintf("ProviderIdentifier{%s, %s}", id.ProviderID, id.ProviderUID)
}

func (id ProviderIdentifier) validate() error {
return validateProvider(id.ProviderID, id.ProviderUID)
}

func (id ProviderIdentifier) matchesUserRecord(ur *UserRecord) (bool, error) {
err := id.validate()
if err != nil {
return false, err
}
for _, userInfo := range ur.ProviderUserInfo {
if id.ProviderID == userInfo.ProviderID && id.ProviderUID == userInfo.UID {
return true, nil
}
}
return false, nil
}

func (id ProviderIdentifier) populateRequest(req *getAccountInfoRequest) error {
err := id.validate()
if err != nil {
return err
}
req.federatedUserID = append(
req.federatedUserID,
federatedUserIdentifier{providerID: id.ProviderID, rawID: id.ProviderUID})
req.FederatedUserID = append(
req.FederatedUserID,
federatedUserIdentifier{ProviderID: id.ProviderID, RawID: id.ProviderUID})
return nil
}

Expand All @@ -747,43 +666,94 @@ type GetUsersResult struct {
}

type federatedUserIdentifier struct {
providerID string
rawID string
ProviderID string `json:"providerId,omitempty"`
RawID string `json:"rawId,omitempty"`
}

type getAccountInfoRequest struct {
localID []string
email []string
phoneNumber []string
federatedUserID []federatedUserIdentifier
LocalID []string `json:"localId"`
Email []string `json:"email"`
PhoneNumber []string `json:"phoneNumber,omitempty"`
FederatedUserID []federatedUserIdentifier `json:"federatedUserId,omitempty"`
}

func (req *getAccountInfoRequest) build() map[string]interface{} {
var builtFederatedUserID []map[string]interface{}
for i := range req.federatedUserID {
builtFederatedUserID = append(builtFederatedUserID, map[string]interface{}{
"providerId": req.federatedUserID[i].providerID,
"rawId": req.federatedUserID[i].rawID,
})
func (req *getAccountInfoRequest) validate() error {
for i := range req.LocalID {
if err := validateUID(req.LocalID[i]); err != nil {
return err
}
}
return map[string]interface{}{
"localId": req.localID,
"email": req.email,
"phoneNumber": req.phoneNumber,
"federatedUserId": builtFederatedUserID,

for i := range req.Email {
if err := validateEmail(req.Email[i]); err != nil {
return err
}
}

for i := range req.PhoneNumber {
if err := validatePhone(req.PhoneNumber[i]); err != nil {
return err
}
}

for i := range req.FederatedUserID {
id := &req.FederatedUserID[i]
if err := validateProvider(id.ProviderID, id.RawID); err != nil {
return err
}
}

return nil
}

func isUserFound(id UserIdentifier, urs [](*UserRecord)) (bool, error) {
func isUserFound(id UserIdentifier, urs [](*UserRecord)) bool {
for i := range urs {
match, err := id.matchesUserRecord(urs[i])
if err != nil {
return false, err
} else if match {
return true, nil
if uidIdentifier, ok := id.(*UIDIdentifier); ok {
if uidIdentifier.UID == urs[i].UID {
return true
}
}
if uidIdentifier, ok := id.(UIDIdentifier); ok {
if uidIdentifier.UID == urs[i].UID {
return true
}
}
if emailIdentifier, ok := id.(*EmailIdentifier); ok {
if emailIdentifier.Email == urs[i].Email {
return true
}
}
if emailIdentifier, ok := id.(EmailIdentifier); ok {
if emailIdentifier.Email == urs[i].Email {
return true
}
}
if phoneIdentifier, ok := id.(*PhoneIdentifier); ok {
if phoneIdentifier.PhoneNumber == urs[i].PhoneNumber {
return true
}
}
if phoneIdentifier, ok := id.(PhoneIdentifier); ok {
if phoneIdentifier.PhoneNumber == urs[i].PhoneNumber {
return true
}
}
if providerIdentifier, ok := id.(*ProviderIdentifier); ok {
for _, userInfo := range urs[i].ProviderUserInfo {
if providerIdentifier.ProviderID == userInfo.ProviderID && providerIdentifier.ProviderUID == userInfo.UID {
return true
}
}
}
if providerIdentifier, ok := id.(ProviderIdentifier); ok {
for _, userInfo := range urs[i].ProviderUserInfo {
if providerIdentifier.ProviderID == userInfo.ProviderID && providerIdentifier.ProviderUID == userInfo.UID {
return true
}
}
}
}
return false, nil
return false
}

func (c *baseClient) GetUsers(
Expand All @@ -792,22 +762,23 @@ func (c *baseClient) GetUsers(
if len(identifiers) == 0 {
return &GetUsersResult{[](*UserRecord){}, [](UserIdentifier){}}, nil
} else if len(identifiers) > maxGetAccountsBatchSize {
return nil, internal.Errorf(
maximumUserCountExceeded,
"`identifiers` parameter must have <= %d entries.", maxGetAccountsBatchSize)
return nil, fmt.Errorf(
"`identifiers` parameter must have <= %d entries", maxGetAccountsBatchSize)
}

var request getAccountInfoRequest
for i := range identifiers {
err := identifiers[i].populateRequest(&request)
if err != nil {
if err := identifiers[i].populateRequest(&request); err != nil {
return nil, err
}
}

if err := request.validate(); err != nil {
return nil, err
}

var parsed getAccountInfoResponse
_, err := c.post(ctx, "/accounts:lookup", request.build(), &parsed)
if err != nil {
if _, err := c.post(ctx, "/accounts:lookup", request, &parsed); err != nil {
return nil, err
}

Expand All @@ -825,11 +796,7 @@ func (c *baseClient) GetUsers(

var notFound []UserIdentifier
for i := range identifiers {
userFound, err := isUserFound(identifiers[i], userRecords)
if err != nil {
return nil, err
}
if !userFound {
if !isUserFound(identifiers[i], userRecords) {
notFound = append(notFound, identifiers[i])
}
}
Expand Down Expand Up @@ -869,8 +836,7 @@ func (r *userQueryResponse) makeUserRecord() (*UserRecord, error) {
func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error) {
var customClaims map[string]interface{}
if r.CustomAttributes != "" {
err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims)
if err != nil {
if err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims); err != nil {
return nil, err
}
if len(customClaims) == 0 {
Expand Down Expand Up @@ -1037,8 +1003,8 @@ type DeleteUsersResult struct {
// The Index field corresponds to the index of the failed user in the uids
// array that was passed to DeleteUsers().
type DeleteUsersErrorInfo struct {
Index int
Reason string
Index int `json:"index,omitEmpty"`
Reason string `json:"message,omitEmpty"`
}

// Deletes the users specified by the given identifiers.
Expand All @@ -1064,9 +1030,8 @@ func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUse
if len(uids) == 0 {
return &DeleteUsersResult{}, nil
} else if len(uids) > maxDeleteAccountsBatchSize {
return nil, internal.Errorf(
maximumUserCountExceeded,
"`uids` parameter must have <= %d entries.", maxDeleteAccountsBatchSize)
return nil, fmt.Errorf(
"`uids` parameter must have <= %d entries", maxDeleteAccountsBatchSize)
}

var payload struct {
Expand All @@ -1083,31 +1048,22 @@ func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUse
payload.LocalIds = append(payload.LocalIds, uids[i])
}

type batchDeleteErrorInfo struct {
Index int `json:"index"`
LocalID string `json:"localId"`
Message string `json:"message"`
}
type batchDeleteAccountsResponse struct {
Errors []batchDeleteErrorInfo `json:"errors"`
Errors []*DeleteUsersErrorInfo `json:"errors"`
}

resp := batchDeleteAccountsResponse{}
_, err := c.post(ctx, "/accounts:batchDelete", payload, &resp)
if _, err := c.post(ctx, "/accounts:batchDelete", payload, &resp); err != nil {
return nil, err
}

result := DeleteUsersResult{
FailureCount: len(resp.Errors),
SuccessCount: len(uids) - len(resp.Errors),
Errors: resp.Errors,
}

for i := range resp.Errors {
result.Errors = append(result.Errors, &DeleteUsersErrorInfo{
Index: resp.Errors[i].Index,
Reason: resp.Errors[i].Message,
})
}

return &result, err
return &result, nil
}

// SessionCookie creates a new Firebase session cookie from the given ID token and expiry
Expand Down
Loading