10 changes: 4 additions & 6 deletions src/databricks/sql/client.py
@@ -248,12 +248,6 @@ def read(self) -> Optional[OAuthToken]:
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

self.server_telemetry_enabled = True
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = (
self.client_telemetry_enabled and self.server_telemetry_enabled
)
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
)
@@ -288,6 +282,10 @@ def read(self) -> Optional[OAuthToken]:
)
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
self.enable_telemetry = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)

TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=self.telemetry_enabled,
session_id_hex=self.get_session_id_hex(),
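For orientation, here is a minimal usage sketch (not part of the diff) of the two new connection options; the hostname, HTTP path, and token below are placeholders, and only `enable_telemetry` / `force_enable_telemetry` come from this change:

from databricks import sql

# Opt in on the client side; telemetry still depends on the server-side feature flag.
conn = sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="example-token",                    # placeholder
    enable_telemetry=True,
)

# Force telemetry on, skipping the feature-flag lookup entirely.
forced_conn = sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="example-token",                    # placeholder
    force_enable_telemetry=True,
)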
176 changes: 176 additions & 0 deletions src/databricks/sql/common/feature_flag.py
@@ -0,0 +1,176 @@
import threading
import time
import requests
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, List, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Connection


@dataclass
class FeatureFlagEntry:
"""Represents a single feature flag from the server response."""

name: str
value: str


@dataclass
class FeatureFlagsResponse:
"""Represents the full JSON response from the feature flag endpoint."""

flags: List[FeatureFlagEntry] = field(default_factory=list)
ttl_seconds: Optional[int] = None

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
"""Factory method to create an instance from a dictionary (parsed JSON)."""
flags_data = data.get("flags", [])
flags_list = [FeatureFlagEntry(**flag) for flag in flags_data]
return cls(flags=flags_list, ttl_seconds=data.get("ttl_seconds"))


# --- Constants ---
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
"/api/2.0/connector-service/feature-flags/PYTHON/{}"
)
DEFAULT_TTL_SECONDS = 900 # 15 minutes
REFRESH_BEFORE_EXPIRY_SECONDS = 10 # Start proactive refresh 10s before expiry


class FeatureFlagsContext:
"""
Manages fetching and caching of server-side feature flags for a connection.

1. The very first check for any flag is a synchronous, BLOCKING operation.
2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
in the background, returning stale data until the refresh completes.
"""

def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
from databricks.sql import __version__

self._connection = connection
self._executor = executor # Used for ASYNCHRONOUS refreshes
self._lock = threading.RLock()

# Cache state: `None` indicates the cache has never been loaded.
self._flags: Optional[Dict[str, str]] = None
self._ttl_seconds: int = DEFAULT_TTL_SECONDS
self._last_refresh_time: float = 0

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
)

def _is_refresh_needed(self) -> bool:
"""Checks if the cache is due for a proactive background refresh."""
if self._flags is None:
return False # Not eligible for refresh until loaded once.

refresh_threshold = self._last_refresh_time + (
self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
)
return time.monotonic() > refresh_threshold

def get_flag_value(self, name: str, default_value: Any) -> Any:
"""
Checks if a feature is enabled.
- BLOCKS on the first call until flags are fetched.
- Returns cached values on subsequent calls, triggering non-blocking refreshes if needed.
"""
with self._lock:
# If cache has never been loaded, perform a synchronous, blocking fetch.
if self._flags is None:
self._refresh_flags()

# If a proactive background refresh is needed, start one. This is non-blocking.
elif self._is_refresh_needed():
# We don't check for an in-flight refresh; the executor queues the task, which is safe.
self._executor.submit(self._refresh_flags)

assert self._flags is not None

# Now, return the value from the populated cache.
return self._flags.get(name, default_value)

def _refresh_flags(self):
"""Performs a synchronous network request to fetch and update flags."""
headers = {}
try:
# Authenticate the request
self._connection.session.auth_provider.add_headers(headers)
headers["User-Agent"] = self._connection.session.useragent_header

response = requests.get(
self._feature_flag_endpoint, headers=headers, timeout=30
)

if response.status_code == 200:
ff_response = FeatureFlagsResponse.from_dict(response.json())
self._update_cache_from_response(ff_response)
else:
# On failure, initialize with an empty dictionary to prevent re-blocking.
if self._flags is None:
self._flags = {}

except Exception:
# On exception, initialize with an empty dictionary to prevent re-blocking.
if self._flags is None:
self._flags = {}

def _update_cache_from_response(self, ff_response: FeatureFlagsResponse):
"""Atomically updates the internal cache state from a successful server response."""
with self._lock:
self._flags = {flag.name: flag.value for flag in ff_response.flags}
if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
self._ttl_seconds = ff_response.ttl_seconds
self._last_refresh_time = time.monotonic()


class FeatureFlagsContextFactory:
"""
Manages a singleton instance of FeatureFlagsContext per connection session.
Also manages a shared ThreadPoolExecutor for all background refresh operations.
"""

_context_map: Dict[str, FeatureFlagsContext] = {}
_executor: Optional[ThreadPoolExecutor] = None
_lock = threading.Lock()

@classmethod
def _initialize(cls):
"""Initializes the shared executor for async refreshes if it doesn't exist."""
if cls._executor is None:
cls._executor = ThreadPoolExecutor(
max_workers=3, thread_name_prefix="feature-flag-refresher"
)

@classmethod
def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
"""Gets or creates a FeatureFlagsContext for the given connection."""
with cls._lock:
cls._initialize()
assert cls._executor is not None

# Use the unique session ID as the key
key = connection.get_session_id_hex()
if key not in cls._context_map:
cls._context_map[key] = FeatureFlagsContext(connection, cls._executor)
return cls._context_map[key]

@classmethod
def remove_instance(cls, connection: "Connection"):
"""Removes the context for a given connection and shuts down the executor if no clients remain."""
with cls._lock:
key = connection.get_session_id_hex()
if key in cls._context_map:
cls._context_map.pop(key, None)

# If this was the last active context, clean up the thread pool.
if not cls._context_map and cls._executor is not None:
cls._executor.shutdown(wait=False)
cls._executor = None
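A rough usage sketch for this module (not part of the diff; `connection` stands for an already-open `Connection`):

# The first lookup blocks while the flags are fetched; later lookups are served
# from the cache and may queue a background refresh close to TTL expiry.
context = FeatureFlagsContextFactory.get_instance(connection)
flag_value = context.get_flag_value(
    "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver",
    default_value=False,
)

# On connection close, drop the per-session context so the shared executor can
# be shut down once no contexts remain.
FeatureFlagsContextFactory.remove_instance(connection)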
21 changes: 20 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
@@ -2,7 +2,7 @@
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from databricks.sql.common.http import TelemetryHttpClient
from databricks.sql.telemetry.models.event import (
TelemetryEvent,
@@ -36,6 +36,10 @@
import uuid
import locale
from databricks.sql.telemetry.utils import BaseTelemetryClient
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory

if TYPE_CHECKING:
from databricks.sql.client import Connection

logger = logging.getLogger(__name__)

@@ -44,6 +48,7 @@ class TelemetryHelper:
"""Helper class for getting telemetry related information."""

_DRIVER_SYSTEM_CONFIGURATION = None
TELEMETRY_FEATURE_FLAG_NAME = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver"

@classmethod
def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
@@ -98,6 +103,20 @@ def get_auth_flow(auth_provider):
else:
return None

@staticmethod
def is_telemetry_enabled(connection: "Connection") -> bool:
if connection.force_enable_telemetry:
return True

if connection.enable_telemetry:
context = FeatureFlagsContextFactory.get_instance(connection)
flag_value = context.get_flag_value(
TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False
)
return str(flag_value).lower() == "true"
else:
return False


class NoopTelemetryClient(BaseTelemetryClient):
"""
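Condensed restatement of that decision logic as a standalone sketch (for illustration only; the real method reads the flag through FeatureFlagsContextFactory rather than taking it as a parameter):

def telemetry_decision(force_enable: bool, enable: bool, server_flag_value: str) -> bool:
    # force_enable_telemetry wins outright; otherwise the client opt-in must be
    # paired with a server flag value of "true"; anything else disables telemetry.
    if force_enable:
        return True
    if enable:
        return server_flag_value.strip().lower() == "true"
    return False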
2 changes: 1 addition & 1 deletion tests/e2e/test_concurrent_telemetry.py
@@ -76,7 +76,7 @@ def execute_query_worker(thread_id):

time.sleep(random.uniform(0, 0.05))

with self.connection(extra_params={"enable_telemetry": True}) as conn:
with self.connection(extra_params={"force_enable_telemetry": True}) as conn:
# Capture the session ID from the connection before executing the query
session_id_hex = conn.get_session_id_hex()
with capture_lock:
91 changes: 88 additions & 3 deletions tests/unit/test_telemetry.py
@@ -7,14 +7,14 @@
NoopTelemetryClient,
TelemetryClientFactory,
TelemetryHelper,
BaseTelemetryClient,
)
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
from databricks.sql.auth.authenticators import (
AccessTokenAuthProvider,
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from databricks import sql


@pytest.fixture
@@ -311,8 +311,6 @@ def test_connection_failure_sends_correct_telemetry_payload(
mock_session.side_effect = Exception(error_message)

try:
from databricks import sql

sql.connect(server_hostname="test-host", http_path="/test-path")
except Exception as e:
assert str(e) == error_message
@@ -321,3 +319,90 @@
call_arguments = mock_export_failure_log.call_args
assert call_arguments[0][0] == "Exception"
assert call_arguments[0][1] == error_message


@patch("databricks.sql.client.Session")
class TestTelemetryFeatureFlag:
"""Tests the interaction between the telemetry feature flag and connection parameters."""

def _mock_ff_response(self, mock_requests_get, enabled: bool):
"""Helper to configure the mock response for the feature flag endpoint."""
mock_response = MagicMock()
mock_response.status_code = 200
payload = {
"flags": [
{
"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver",
"value": str(enabled).lower(),
}
],
"ttl_seconds": 3600,
}
mock_response.json.return_value = payload
mock_requests_get.return_value = mock_response

@patch("databricks.sql.common.feature_flag.requests.get")
def test_telemetry_enabled_when_flag_is_true(
self, mock_requests_get, MockSession
):
"""Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
self._mock_ff_response(mock_requests_get, enabled=True)
mock_session_instance = MockSession.return_value
mock_session_instance.guid_hex = "test-session-ff-true"
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")

conn = sql.client.Connection(
server_hostname="test",
http_path="test",
access_token="test",
enable_telemetry=True,
)

assert conn.telemetry_enabled is True
mock_requests_get.assert_called_once()
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true")
assert isinstance(client, TelemetryClient)

@patch("databricks.sql.common.feature_flag.requests.get")
def test_telemetry_disabled_when_flag_is_false(
self, mock_requests_get, MockSession
):
"""Telemetry should be OFF when enable_telemetry=True but server flag is 'false'."""
self._mock_ff_response(mock_requests_get, enabled=False)
mock_session_instance = MockSession.return_value
mock_session_instance.guid_hex = "test-session-ff-false"
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")

conn = sql.client.Connection(
server_hostname="test",
http_path="test",
access_token="test",
enable_telemetry=True,
)

assert conn.telemetry_enabled is False
mock_requests_get.assert_called_once()
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false")
assert isinstance(client, NoopTelemetryClient)

@patch("databricks.sql.common.feature_flag.requests.get")
def test_telemetry_disabled_when_flag_request_fails(
self, mock_requests_get, MockSession
):
"""Telemetry should default to OFF if the feature flag network request fails."""
mock_requests_get.side_effect = Exception("Network is down")
mock_session_instance = MockSession.return_value
mock_session_instance.guid_hex = "test-session-ff-fail"
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")

conn = sql.client.Connection(
server_hostname="test",
http_path="test",
access_token="test",
enable_telemetry=True,
)

assert conn.telemetry_enabled is False
mock_requests_get.assert_called_once()
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
assert isinstance(client, NoopTelemetryClient)