Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.

Commit 6a90f66

Browse files
committed
fix base64 token validation and add tests for decode_jwt
1 parent 2a4bad5 commit 6a90f66

File tree

5 files changed

+245
-25
lines changed

5 files changed

+245
-25
lines changed

supabase_auth/_async/gotrue_client.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
AuthResponse,
5959
ClaimsResponse,
6060
CodeExchangeParams,
61-
DecodedJWTDict,
6261
IdentitiesResponse,
6362
JWKSet,
6463
MFAChallengeAndVerifyParams,
@@ -687,7 +686,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
687686
has_expired = True
688687
session: Optional[Session] = None
689688
if access_token and access_token.split(".")[1]:
690-
payload = self._decode_jwt(access_token)
689+
payload = decode_jwt(access_token)["payload"]
691690
exp = payload.get("exp")
692691
if exp:
693692
expires_at = int(exp)
@@ -899,7 +898,7 @@ async def _get_authenticator_assurance_level(
899898
next_level=None,
900899
current_authentication_methods=[],
901900
)
902-
payload = self._decode_jwt(session.access_token)
901+
payload = decode_jwt(session.access_token)["payload"]
903902
current_level: Optional[AuthenticatorAssuranceLevels] = None
904903
if payload.get("aal"):
905904
current_level = payload.get("aal")
@@ -1137,13 +1136,6 @@ async def _get_url_for_provider(
11371136
query = urlencode(params)
11381137
return f"{url}?{query}", params
11391138

1140-
def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
1141-
"""
1142-
Decodes a JWT (without performing any validation).
1143-
"""
1144-
decoded = decode_jwt(jwt)
1145-
return decoded["payload"]
1146-
11471139
async def exchange_code_for_session(self, params: CodeExchangeParams):
11481140
code_verifier = params.get("code_verifier") or await self._storage.get_item(
11491141
f"{self._storage_key}-code-verifier"

supabase_auth/_sync/gotrue_client.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
AuthResponse,
5959
ClaimsResponse,
6060
CodeExchangeParams,
61-
DecodedJWTDict,
6261
IdentitiesResponse,
6362
JWKSet,
6463
MFAChallengeAndVerifyParams,
@@ -685,7 +684,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
685684
has_expired = True
686685
session: Optional[Session] = None
687686
if access_token and access_token.split(".")[1]:
688-
payload = self._decode_jwt(access_token)
687+
payload = decode_jwt(access_token)["payload"]
689688
exp = payload.get("exp")
690689
if exp:
691690
expires_at = int(exp)
@@ -895,7 +894,7 @@ def _get_authenticator_assurance_level(
895894
next_level=None,
896895
current_authentication_methods=[],
897896
)
898-
payload = self._decode_jwt(session.access_token)
897+
payload = decode_jwt(session.access_token)["payload"]
899898
current_level: Optional[AuthenticatorAssuranceLevels] = None
900899
if payload.get("aal"):
901900
current_level = payload.get("aal")
@@ -1131,13 +1130,6 @@ def _get_url_for_provider(
11311130
query = urlencode(params)
11321131
return f"{url}?{query}", params
11331132

1134-
def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
1135-
"""
1136-
Decodes a JWT (without performing any validation).
1137-
"""
1138-
decoded = decode_jwt(jwt)
1139-
return decoded["payload"]
1140-
11411133
def exchange_code_for_session(self, params: CodeExchangeParams):
11421134
code_verifier = params.get("code_verifier") or self._storage.get_item(
11431135
f"{self._storage_key}-code-verifier"

supabase_auth/helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ def decode_jwt(token: str) -> DecodedJWT:
229229
raise AuthInvalidJwtError("Invalid JWT structure")
230230

231231
# regex check for base64url
232-
# for part in parts:
233-
# if not re.match(BASE64URL_REGEX, part):
234-
# raise AuthInvalidJwtError("JWT not in base64url format")
232+
for part in parts:
233+
if not re.match(BASE64URL_REGEX, part, re.IGNORECASE):
234+
raise AuthInvalidJwtError("JWT not in base64url format")
235235

236236
return DecodedJWT(
237237
header=JWTHeader(**loads(str_from_base64url(parts[0]))),

tests/_async/test_gotrue.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import time
22
import unittest
33

4-
from .clients import auth_client, auth_client_with_asymmetric_session
4+
import pytest
5+
from jwt import encode
6+
7+
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
8+
from supabase_auth.helpers import decode_jwt
9+
10+
from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
511
from .utils import mock_user_credentials
612

713

@@ -71,3 +77,115 @@ async def test_jwks_ttl_cache_behavior(mocker):
7177
finally:
7278
# Restore original time function
7379
mocker.patch("time.time", original_time)
80+
81+
82+
async def test_set_session_with_valid_tokens():
83+
client = auth_client()
84+
credentials = mock_user_credentials()
85+
86+
# First sign up to get valid tokens
87+
signup_response = await client.sign_up(
88+
{
89+
"email": credentials.get("email"),
90+
"password": credentials.get("password"),
91+
}
92+
)
93+
assert signup_response.session is not None
94+
95+
# Get the tokens from the signup response
96+
access_token = signup_response.session.access_token
97+
refresh_token = signup_response.session.refresh_token
98+
99+
# Clear the session
100+
await client._remove_session()
101+
102+
# Set the session with the tokens
103+
response = await client.set_session(access_token, refresh_token)
104+
105+
# Verify the response
106+
assert response.session is not None
107+
assert response.session.access_token == access_token
108+
assert response.session.refresh_token == refresh_token
109+
assert response.user is not None
110+
assert response.user.email == credentials.get("email")
111+
112+
113+
async def test_set_session_with_expired_token():
114+
client = auth_client()
115+
credentials = mock_user_credentials()
116+
117+
# First sign up to get valid tokens
118+
signup_response = await client.sign_up(
119+
{
120+
"email": credentials.get("email"),
121+
"password": credentials.get("password"),
122+
}
123+
)
124+
assert signup_response.session is not None
125+
126+
# Get the tokens from the signup response
127+
access_token = signup_response.session.access_token
128+
refresh_token = signup_response.session.refresh_token
129+
130+
# Clear the session
131+
await client._remove_session()
132+
133+
# Create an expired token by modifying the JWT
134+
expired_token = access_token.split(".")
135+
payload = decode_jwt(access_token)["payload"]
136+
payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago
137+
expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
138+
1
139+
]
140+
expired_access_token = ".".join(expired_token)
141+
142+
# Set the session with the expired token
143+
response = await client.set_session(expired_access_token, refresh_token)
144+
145+
# Verify the response has a new access token (refreshed)
146+
assert response.session is not None
147+
assert response.session.access_token != expired_access_token
148+
assert response.session.refresh_token != refresh_token
149+
assert response.user is not None
150+
assert response.user.email == credentials.get("email")
151+
152+
153+
async def test_set_session_without_refresh_token():
154+
client = auth_client()
155+
credentials = mock_user_credentials()
156+
157+
# First sign up to get valid tokens
158+
signup_response = await client.sign_up(
159+
{
160+
"email": credentials.get("email"),
161+
"password": credentials.get("password"),
162+
}
163+
)
164+
assert signup_response.session is not None
165+
166+
# Get the access token from the signup response
167+
access_token = signup_response.session.access_token
168+
169+
# Clear the session
170+
await client._remove_session()
171+
172+
# Create an expired token
173+
expired_token = access_token.split(".")
174+
payload = decode_jwt(access_token)["payload"]
175+
payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago
176+
expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
177+
1
178+
]
179+
expired_access_token = ".".join(expired_token)
180+
181+
# Try to set the session with an expired token but no refresh token
182+
with pytest.raises(AuthSessionMissingError):
183+
await client.set_session(expired_access_token, "")
184+
185+
186+
async def test_set_session_with_invalid_token():
187+
client = auth_client()
188+
189+
# Try to set the session with invalid tokens
190+
with pytest.raises(AuthInvalidJwtError):
191+
await client.set_session("invalid.token.here", "invalid_refresh_token")

tests/_sync/test_gotrue.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import time
22
import unittest
33

4-
from .clients import auth_client, auth_client_with_asymmetric_session
4+
import pytest
5+
from jwt import encode
6+
7+
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
8+
from supabase_auth.helpers import decode_jwt
9+
10+
from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
511
from .utils import mock_user_credentials
612

713

@@ -71,3 +77,115 @@ def test_jwks_ttl_cache_behavior(mocker):
7177
finally:
7278
# Restore original time function
7379
mocker.patch("time.time", original_time)
80+
81+
82+
def test_set_session_with_valid_tokens():
83+
client = auth_client()
84+
credentials = mock_user_credentials()
85+
86+
# First sign up to get valid tokens
87+
signup_response = client.sign_up(
88+
{
89+
"email": credentials.get("email"),
90+
"password": credentials.get("password"),
91+
}
92+
)
93+
assert signup_response.session is not None
94+
95+
# Get the tokens from the signup response
96+
access_token = signup_response.session.access_token
97+
refresh_token = signup_response.session.refresh_token
98+
99+
# Clear the session
100+
client._remove_session()
101+
102+
# Set the session with the tokens
103+
response = client.set_session(access_token, refresh_token)
104+
105+
# Verify the response
106+
assert response.session is not None
107+
assert response.session.access_token == access_token
108+
assert response.session.refresh_token == refresh_token
109+
assert response.user is not None
110+
assert response.user.email == credentials.get("email")
111+
112+
113+
def test_set_session_with_expired_token():
114+
client = auth_client()
115+
credentials = mock_user_credentials()
116+
117+
# First sign up to get valid tokens
118+
signup_response = client.sign_up(
119+
{
120+
"email": credentials.get("email"),
121+
"password": credentials.get("password"),
122+
}
123+
)
124+
assert signup_response.session is not None
125+
126+
# Get the tokens from the signup response
127+
access_token = signup_response.session.access_token
128+
refresh_token = signup_response.session.refresh_token
129+
130+
# Clear the session
131+
client._remove_session()
132+
133+
# Create an expired token by modifying the JWT
134+
expired_token = access_token.split(".")
135+
payload = decode_jwt(access_token)["payload"]
136+
payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago
137+
expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
138+
1
139+
]
140+
expired_access_token = ".".join(expired_token)
141+
142+
# Set the session with the expired token
143+
response = client.set_session(expired_access_token, refresh_token)
144+
145+
# Verify the response has a new access token (refreshed)
146+
assert response.session is not None
147+
assert response.session.access_token != expired_access_token
148+
assert response.session.refresh_token != refresh_token
149+
assert response.user is not None
150+
assert response.user.email == credentials.get("email")
151+
152+
153+
def test_set_session_without_refresh_token():
154+
client = auth_client()
155+
credentials = mock_user_credentials()
156+
157+
# First sign up to get valid tokens
158+
signup_response = client.sign_up(
159+
{
160+
"email": credentials.get("email"),
161+
"password": credentials.get("password"),
162+
}
163+
)
164+
assert signup_response.session is not None
165+
166+
# Get the access token from the signup response
167+
access_token = signup_response.session.access_token
168+
169+
# Clear the session
170+
client._remove_session()
171+
172+
# Create an expired token
173+
expired_token = access_token.split(".")
174+
payload = decode_jwt(access_token)["payload"]
175+
payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago
176+
expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
177+
1
178+
]
179+
expired_access_token = ".".join(expired_token)
180+
181+
# Try to set the session with an expired token but no refresh token
182+
with pytest.raises(AuthSessionMissingError):
183+
client.set_session(expired_access_token, "")
184+
185+
186+
def test_set_session_with_invalid_token():
187+
client = auth_client()
188+
189+
# Try to set the session with invalid tokens
190+
with pytest.raises(AuthInvalidJwtError):
191+
client.set_session("invalid.token.here", "invalid_refresh_token")

0 commit comments

Comments
 (0)