Skip to content

Commit 964ddd2

Browse files
committed
added support for persistence
Signed-off-by: Moe Derakhshani <[email protected]>
1 parent 9642692 commit 964ddd2

File tree

7 files changed

+113
-17
lines changed

7 files changed

+113
-17
lines changed

src/databricks/sql/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,6 @@ def TimestampFromTicks(ticks):
4444
return Timestamp(*time.localtime(ticks)[:6])
4545

4646

47-
def connect(server_hostname, http_path, **kwargs):
47+
def connect(server_hostname, http_path, experimental_oauth_persistence=None, **kwargs):
4848
from .client import Connection
49-
return Connection(server_hostname, http_path, **kwargs)
49+
return Connection(server_hostname, http_path, experimental_oauth_persistence, **kwargs)

src/databricks/sql/auth/auth.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from enum import Enum
2525
from databricks.sql.auth.authenticators import CredentialsProvider, \
2626
AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider
27+
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
2728

2829

2930
class AuthType(Enum):
@@ -42,7 +43,9 @@ def __init__(self,
4243
oauth_scopes: List[str] = None,
4344
oauth_client_id: str = None,
4445
use_cert_as_auth: str = None,
45-
tls_client_cert_file: str = None):
46+
tls_client_cert_file: str = None,
47+
oauth_persistence=None
48+
):
4649
self.hostname = hostname
4750
self.username = username
4851
self.password = password
@@ -52,11 +55,12 @@ def __init__(self,
5255
self.oauth_client_id = oauth_client_id
5356
self.use_cert_as_auth = use_cert_as_auth
5457
self.tls_client_cert_file = tls_client_cert_file
58+
self.oauth_persistence = oauth_persistence
5559

5660

5761
def get_auth_provider(cfg: ClientContext):
5862
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
59-
return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes)
63+
return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_client_id, cfg.oauth_scopes)
6064
elif cfg.access_token is not None:
6165
return AccessTokenAuthProvider(cfg.access_token)
6266
elif cfg.username is not None and cfg.password is not None:
@@ -73,7 +77,7 @@ def get_auth_provider(cfg: ClientContext):
7377
OAUTH_CLIENT_ID = "databricks-cli"
7478

7579

76-
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
80+
def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence, **kwargs):
7781
cfg = ClientContext(hostname=hostname,
7882
auth_type=kwargs.get("auth_type"),
7983
access_token=kwargs.get("access_token"),
@@ -82,7 +86,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
8286
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
8387
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
8488
oauth_scopes=OAUTH_SCOPES,
85-
oauth_client_id=OAUTH_CLIENT_ID)
89+
oauth_client_id=OAUTH_CLIENT_ID,
90+
oauth_persistence=oauth_persistence)
8691
return get_auth_provider(cfg)
8792

8893

src/databricks/sql/auth/authenticators.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
# Private API: this is an evolving interface and it will change in the future.
2929
# Please must not depend on it in your applications.
30+
from databricks.sql.experimental.oauth_persistence import OAuthToken
31+
32+
3033
class CredentialsProvider:
3134
def add_headers(self, request_headers):
3235
pass
@@ -59,18 +62,17 @@ def add_headers(self, request_headers):
5962
# Please must not depend on it in your applications.
6063
class DatabricksOAuthProvider(CredentialsProvider):
6164
SCOPE_DELIM = ' '
62-
# TODO: moderakh the refresh_token is only kept in memory. not saved on disk
63-
# hence if application restarts the user may need to re-authenticate
64-
# I will add support for this outside of the scope of current PR.
65-
def __init__(self, hostname, client_id, scopes):
65+
66+
def __init__(self, hostname, oauth_persistence, client_id, scopes):
6667
self._hostname = self._normalize_host_name(hostname=hostname)
6768
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes)
68-
access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scopes_as_str)
69-
self._access_token = access_token
70-
self._refresh_token = refresh_token
69+
self._oauth_persistence = oauth_persistence
70+
self._client_id = client_id
71+
self._get_tokens()
7172

7273
def add_headers(self, request_headers):
7374
check_and_refresh_access_token(hostname=self._hostname,
75+
client_id=self._client_id,
7476
access_token=self._access_token,
7577
refresh_token=self._refresh_token)
7678
request_headers['Authorization'] = f"Bearer {self._access_token}"
@@ -80,3 +82,35 @@ def _normalize_host_name(hostname):
8082
maybe_scheme = "https://" if not hostname.startswith("https://") else ""
8183
maybe_trailing_slash = "/" if not hostname.endswith("/") else ""
8284
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}"
85+
86+
def _get_tokens(self):
87+
if self._oauth_persistence:
88+
token = self._oauth_persistence.read()
89+
if token:
90+
self._access_token = token.get_access_token()
91+
self._refresh_token = token.get_refresh_token()
92+
self._update_token_if_expired()
93+
else:
94+
(access_token, refresh_token) = get_tokens(hostname=self._hostname,
95+
client_id=self._client_id,
96+
scope=self._scopes_as_str)
97+
self._access_token = access_token
98+
self._refresh_token = refresh_token
99+
self._oauth_persistence.persist(OAuthToken(access_token, refresh_token))
100+
101+
def _update_token_if_expired(self):
102+
(fresh_access_token, fresh_refresh_token, is_refreshed) = check_and_refresh_access_token(
103+
hostname=self._hostname,
104+
client_id=self._client_id,
105+
access_token=self._access_token,
106+
refresh_token=self._refresh_token)
107+
108+
if not is_refreshed:
109+
return
110+
else:
111+
self._access_token = fresh_access_token
112+
self._refresh_token = fresh_refresh_token
113+
114+
if self._oauth_persistence:
115+
token = OAuthToken(self._access_token, self._refresh_token)
116+
self._oauth_persistence.persist(token)

src/databricks/sql/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def get_tokens_from_response(oauth_response):
220220
return access_token, refresh_token
221221

222222

223-
def check_and_refresh_access_token(hostname, access_token, refresh_token):
223+
def check_and_refresh_access_token(hostname, client_id, access_token, refresh_token):
224224
now = datetime.now(tz=UTC)
225225
# If we can't decode an expiration time, this will be expired by default.
226226
expiration_time = now
@@ -246,7 +246,7 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token):
246246

247247
# Try to refresh using the refresh token
248248
logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}")
249-
oauth_response = send_refresh_token_request(hostname, refresh_token)
249+
oauth_response = send_refresh_token_request(hostname, client_id, refresh_token)
250250
fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response)
251251
return fresh_access_token, fresh_refresh_token, True
252252

src/databricks/sql/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from databricks.sql.utils import ExecuteResponse, ParamEscaper
1111
from databricks.sql.types import Row
1212
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
13-
13+
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
1414
logger = logging.getLogger(__name__)
1515

1616
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
@@ -22,6 +22,7 @@ def __init__(
2222
self,
2323
server_hostname: str,
2424
http_path: str,
25+
oauth_persistence: OAuthPersistence = None,
2526
http_headers: Optional[List[Tuple[str, str]]] = None,
2627
session_configuration: Dict[str, Any] = None,
2728
catalog: Optional[str] = None,
@@ -85,7 +86,7 @@ def __init__(
8586
self.port = kwargs.get("_port", 443)
8687
self.disable_pandas = kwargs.get("_disable_pandas", False)
8788

88-
auth_provider = get_python_sql_connector_auth_provider(server_hostname, **kwargs)
89+
auth_provider = get_python_sql_connector_auth_provider(server_hostname, oauth_persistence, **kwargs)
8990

9091
if not kwargs.get("_user_agent_entry"):
9192
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

src/databricks/sql/experimental/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import logging
2+
import json
3+
logger = logging.getLogger(__name__)
4+
5+
6+
class OAuthToken:
7+
def __init__(self, access_token, refresh_token):
8+
self._access_token = access_token
9+
self._refresh_token = refresh_token
10+
11+
def get_access_token(self) -> str:
12+
return self._access_token
13+
14+
def get_refresh_token(self) -> str:
15+
return self._refresh_token
16+
17+
18+
class OAuthPersistence:
19+
def persist(self, oauth_token: OAuthToken):
20+
pass
21+
22+
def read(self) -> OAuthToken:
23+
pass
24+
25+
26+
# Note this is only intended to be used for development
27+
class DevOnlyFilePersistence(OAuthPersistence):
28+
29+
def __init__(self, file_path):
30+
self._file_path = file_path
31+
32+
def persist(self, token: OAuthToken):
33+
logger.info(f"persisting token in {self._file_path}")
34+
35+
# Data to be written
36+
dictionary = {
37+
"refresh_token": token.get_refresh_token(),
38+
"access_token": token.get_access_token()
39+
}
40+
41+
# Serializing json
42+
json_object = json.dumps(dictionary, indent=4)
43+
44+
with open(self._file_path, "w") as outfile:
45+
outfile.write(json_object)
46+
47+
def read(self) -> OAuthToken:
48+
# TODO: validate the
49+
try:
50+
with open(self._file_path, "r") as infile:
51+
json_as_string = infile.read()
52+
53+
token_as_json = json.loads(json_as_string)
54+
return OAuthToken(token_as_json['access_token'], token_as_json['refresh_token'])
55+
except Exception as e:
56+
return None

0 commit comments

Comments
 (0)