Skip to content

Commit 21a60bc

Browse files
committed
DataConnectAuth.kt: add authUids to GetAuthTokenResult
This is a partial cherry-pick of #7485
1 parent a1e7e86 commit 21a60bc

File tree

3 files changed

+100
-7
lines changed

3 files changed

+100
-7
lines changed

firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAuth.kt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,31 @@ internal class DataConnectAuth(
4949
provider.removeIdTokenListener(idTokenListener)
5050

5151
override suspend fun getToken(provider: InternalAuthProvider, forceRefresh: Boolean) =
52-
provider.getAccessToken(forceRefresh).await().let { GetAuthTokenResult(it.token) }
52+
provider.getAccessToken(forceRefresh).await().let {
53+
GetAuthTokenResult(it.token, it.getAuthUids())
54+
}
5355

54-
data class GetAuthTokenResult(override val token: String?) : GetTokenResult
56+
data class GetAuthTokenResult(override val token: String?, val authUids: Set<String>) :
57+
GetTokenResult
5558

5659
private class IdTokenListenerImpl(private val logger: Logger) : IdTokenListener {
5760
override fun onIdTokenChanged(tokenResult: InternalTokenResult) {
5861
logger.debug { "onIdTokenChanged(token=${tokenResult.token?.toScrubbedAccessToken()})" }
5962
}
6063
}
64+
65+
private companion object {
66+
67+
val authUidClaimNames = listOf("user_id", "sub")
68+
69+
fun com.google.firebase.auth.GetTokenResult.getAuthUids(): Set<String> = buildSet {
70+
authUidClaimNames.forEach { claimName ->
71+
claims[claimName]?.let { claimValue ->
72+
if (claimValue is String) {
73+
add(claimValue)
74+
}
75+
}
76+
}
77+
}
78+
}
6179
}

firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectAuthUnitTest.kt

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import com.google.firebase.dataconnect.testutil.UnavailableDeferred
3232
import com.google.firebase.dataconnect.testutil.newBackgroundScopeThatAdvancesLikeForeground
3333
import com.google.firebase.dataconnect.testutil.newMockLogger
3434
import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnect
35+
import com.google.firebase.dataconnect.testutil.property.arbitrary.distinctPair
3536
import com.google.firebase.dataconnect.testutil.shouldContainWithNonAbuttingText
3637
import com.google.firebase.dataconnect.testutil.shouldContainWithNonAbuttingTextIgnoringCase
3738
import com.google.firebase.dataconnect.testutil.shouldHaveLoggedAtLeastOneMessageContaining
@@ -46,15 +47,19 @@ import io.kotest.assertions.nondeterministic.eventually
4647
import io.kotest.assertions.nondeterministic.eventuallyConfig
4748
import io.kotest.assertions.throwables.shouldThrow
4849
import io.kotest.assertions.withClue
50+
import io.kotest.matchers.collections.shouldBeEmpty
4951
import io.kotest.matchers.collections.shouldContain
5052
import io.kotest.matchers.collections.shouldContainExactly
53+
import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder
5154
import io.kotest.matchers.nulls.shouldBeNull
5255
import io.kotest.matchers.nulls.shouldNotBeNull
5356
import io.kotest.matchers.shouldBe
5457
import io.kotest.matchers.types.shouldBeSameInstanceAs
5558
import io.kotest.property.Arb
5659
import io.kotest.property.RandomSource
60+
import io.kotest.property.arbitrary.map
5761
import io.kotest.property.arbitrary.next
62+
import io.kotest.property.arbs.products.brand
5863
import io.mockk.coEvery
5964
import io.mockk.confirmVerified
6065
import io.mockk.every
@@ -311,6 +316,74 @@ class DataConnectAuthUnitTest {
311316
mockLogger.shouldNotHaveLoggedAnyMessagesContaining(accessToken)
312317
}
313318

319+
@Test
320+
fun `getToken() should populate authUids from user_id claim`() = runTest {
321+
val dataConnectAuth = newDataConnectAuth()
322+
dataConnectAuth.initialize()
323+
advanceUntilIdle()
324+
val uid = Arb.brand().map { it.value }.next(rs)
325+
coEvery { mockInternalAuthProvider.getAccessToken(any()) } returns
326+
taskForToken(accessToken, mapOf("user_id" to uid))
327+
328+
val result = dataConnectAuth.getToken(requestId)
329+
330+
result.shouldNotBeNull().authUids.shouldContainExactly(uid)
331+
}
332+
333+
@Test
334+
fun `getToken() should populate authUids from sub claim`() = runTest {
335+
val dataConnectAuth = newDataConnectAuth()
336+
dataConnectAuth.initialize()
337+
advanceUntilIdle()
338+
val uid = Arb.brand().map { it.value }.next(rs)
339+
coEvery { mockInternalAuthProvider.getAccessToken(any()) } returns
340+
taskForToken(accessToken, mapOf("sub" to uid))
341+
342+
val result = dataConnectAuth.getToken(requestId)
343+
344+
result.shouldNotBeNull().authUids.shouldContainExactly(uid)
345+
}
346+
347+
@Test
348+
fun `getToken() should populate authUids from user_id and sub claims`() = runTest {
349+
val dataConnectAuth = newDataConnectAuth()
350+
dataConnectAuth.initialize()
351+
advanceUntilIdle()
352+
val (uid1, uid2) = Arb.brand().map { it.value }.distinctPair().next(rs)
353+
coEvery { mockInternalAuthProvider.getAccessToken(any()) } returns
354+
taskForToken(accessToken, mapOf("user_id" to uid1, "sub" to uid2))
355+
356+
val result = dataConnectAuth.getToken(requestId)
357+
358+
result.shouldNotBeNull().authUids.shouldContainExactlyInAnyOrder(uid1, uid2)
359+
}
360+
361+
@Test
362+
fun `getToken() should populate empty authUids if claims are missing`() = runTest {
363+
val dataConnectAuth = newDataConnectAuth()
364+
dataConnectAuth.initialize()
365+
advanceUntilIdle()
366+
coEvery { mockInternalAuthProvider.getAccessToken(any()) } returns
367+
taskForToken(accessToken, emptyMap())
368+
369+
val result = dataConnectAuth.getToken(requestId)
370+
371+
result.shouldNotBeNull().authUids.shouldBeEmpty()
372+
}
373+
374+
@Test
375+
fun `getToken() should ignore non-string uid claims`() = runTest {
376+
val dataConnectAuth = newDataConnectAuth()
377+
dataConnectAuth.initialize()
378+
advanceUntilIdle()
379+
coEvery { mockInternalAuthProvider.getAccessToken(any()) } returns
380+
taskForToken(accessToken, mapOf("user_id" to 123, "sub" to true))
381+
382+
val result = dataConnectAuth.getToken(requestId)
383+
384+
result.shouldNotBeNull().authUids shouldBe emptySet()
385+
}
386+
314387
@Test
315388
fun `getToken() should return re-throw the exception from the task returned from FirebaseAuth`() =
316389
runTest {
@@ -613,7 +686,7 @@ class DataConnectAuthUnitTest {
613686
interval = 100.milliseconds
614687
}
615688

616-
fun taskForToken(token: String?): Task<GetTokenResult> =
617-
Tasks.forResult(mockk(relaxed = true) { every { getToken() } returns token })
689+
fun taskForToken(token: String?, claims: Map<String, Any> = emptyMap()): Task<GetTokenResult> =
690+
Tasks.forResult(GetTokenResult(token, claims))
618691
}
619692
}

firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/testutil/property/arbitrary/arbs.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import io.kotest.property.arbitrary.int
5151
import io.kotest.property.arbitrary.list
5252
import io.kotest.property.arbitrary.map
5353
import io.kotest.property.arbitrary.orNull
54+
import io.kotest.property.arbitrary.set
5455
import io.kotest.property.arbitrary.string
5556
import io.mockk.coEvery
5657
import io.mockk.mockk
@@ -333,9 +334,10 @@ internal inline fun <Data, reified Variables> DataConnectArb.operationRefConstru
333334
}
334335

335336
internal fun DataConnectArb.authTokenResult(
336-
accessToken: Arb<String> = accessToken()
337-
): Arb<GetAuthTokenResult> = accessToken.map { GetAuthTokenResult(it) }
337+
accessToken: Arb<String?> = accessToken().orNull(nullProbability = 0.33),
338+
authUids: Arb<Set<String>> = Arb.set(string(0..10, Codepoint.alphanumeric()), 0..10),
339+
): Arb<GetAuthTokenResult> = Arb.bind(accessToken, authUids, ::GetAuthTokenResult)
338340

339341
internal fun DataConnectArb.appCheckTokenResult(
340342
accessToken: Arb<String> = accessToken()
341-
): Arb<GetAppCheckTokenResult> = accessToken.map { GetAppCheckTokenResult(it) }
343+
): Arb<GetAppCheckTokenResult> = accessToken.map(::GetAppCheckTokenResult)

0 commit comments

Comments
 (0)