10
10
11
11
from descope .common import (
12
12
DEFAULT_BASE_URI ,
13
+ DEFAULT_FETCH_PUBLIC_KEY_URI ,
13
14
EMAIL_REGEX ,
15
+ GET_KEYS_PATH ,
14
16
PHONE_REGEX ,
15
17
SIGNIN_OTP_PATH ,
16
18
SIGNUP_OTP_PATH ,
22
24
23
25
24
26
class AuthClient :
25
- def __init__ (self , project_id : str , public_key : str ):
26
-
27
+ def __init__ (self , project_id : str , public_key : str = None ):
27
28
# validate project id
28
29
if project_id is None or project_id == "" :
29
30
# try get the project_id from env
@@ -37,14 +38,15 @@ def __init__(self, project_id: str, public_key: str):
37
38
self .project_id = project_id
38
39
39
40
if public_key is None or public_key == "" :
40
- public_key = os .getenv ("DESCOPE_PUBLIC_KEY" , "" )
41
- if public_key == "" :
42
- raise AuthException (
43
- 500 ,
44
- "Init failure" ,
45
- "Failed to init AuthClient object, public key cannot be found" ,
46
- )
41
+ public_key = os .getenv ("DESCOPE_PUBLIC_KEY" , None )
42
+
43
+ if public_key is None :
44
+ self .public_key = None # public key will be fetch later (on demand)
45
+ else :
46
+ self .public_key = self ._validate_and_load_public_key (public_key )
47
47
48
+ @staticmethod
49
+ def _validate_and_load_public_key (public_key ) -> jwt .PyJWK :
48
50
if isinstance (public_key , str ):
49
51
try :
50
52
public_key = json .loads (public_key )
@@ -64,7 +66,7 @@ def __init__(self, project_id: str, public_key: str):
64
66
65
67
try :
66
68
# Load and validate public key
67
- self . public_key = jwt .PyJWK (public_key )
69
+ return jwt .PyJWK (public_key )
68
70
except jwt .InvalidKeyError as e :
69
71
raise AuthException (
70
72
500 ,
@@ -78,6 +80,40 @@ def __init__(self, project_id: str, public_key: str):
78
80
f"Failed to init AuthClient object, failed to load public key { e } " ,
79
81
)
80
82
83
+ def _fetch_public_key (self , kid : str ) -> None :
84
+ response = requests .get (
85
+ f"{ DEFAULT_FETCH_PUBLIC_KEY_URI } { GET_KEYS_PATH } /{ self .project_id } " ,
86
+ headers = self ._get_default_headers (),
87
+ )
88
+
89
+ if not response .ok :
90
+ raise AuthException (
91
+ 401 , "public key fetching failed" , f"err: { response .reason } "
92
+ )
93
+
94
+ jwks_data = response .text
95
+ try :
96
+ jwkeys = json .loads (jwks_data )
97
+ except Exception as e :
98
+ raise AuthException (
99
+ 401 , "public key fetching failed" , f"Failed to load jwks { e } "
100
+ )
101
+
102
+ founded_key = None
103
+ for key in jwkeys :
104
+ if key ["kid" ] == kid :
105
+ founded_key = key
106
+ break
107
+
108
+ if founded_key :
109
+ self .public_key = AuthClient ._validate_and_load_public_key (founded_key )
110
+ else :
111
+ raise AuthException (
112
+ 401 ,
113
+ "public key validation failed" ,
114
+ "Failed to validate public key, public key not found" ,
115
+ )
116
+
81
117
@staticmethod
82
118
def _verify_delivery_method (method : DeliveryMethod , identifier : str ) -> bool :
83
119
if identifier == "" or identifier is None :
@@ -219,23 +255,29 @@ def validate_session_request(self, signed_token):
219
255
"""
220
256
DOC
221
257
"""
258
+
222
259
try :
223
260
unverified_header = jwt .get_unverified_header (signed_token )
224
261
except Exception as e :
225
262
raise AuthException (
226
263
401 ,
227
264
"token validation failure" ,
228
- f"Failed to get unverified token header, { e } " ,
265
+ f"Failed to parse token header, { e } " ,
229
266
)
230
- token_type = unverified_header . get ( "typ" , None )
231
- alg = unverified_header .get ("alg " , None )
232
- if token_type is None or alg is None :
267
+
268
+ kid = unverified_header .get ("kid " , None )
269
+ if kid is None :
233
270
raise AuthException (
234
271
401 ,
235
272
"token validation failure" ,
236
- f "Token header is missing token type or algorithm, token_type= { token_type } alg= { alg } " ,
273
+ "Token header is missing kid property " ,
237
274
)
238
275
276
+ if self .public_key is None :
277
+ self ._fetch_public_key (
278
+ kid
279
+ ) # will set self.public_key or raise exception if failed
280
+
239
281
try :
240
282
jwt .decode (jwt = signed_token , key = self .public_key .key , algorithms = ["ES384" ])
241
283
except Exception as e :
0 commit comments