Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pyarrow = [
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
pyjwt = "^2.0.0"
pybreaker = "^1.0.0"
requests-kerberos = {version = "^0.15.0", optional = true}


Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
user_agent: Optional[str] = None,
telemetry_circuit_breaker_enabled: Optional[bool] = None,
):
self.hostname = hostname
self.access_token = access_token
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
self.user_agent = user_agent
self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled)


def get_effective_azure_login_app_id(hostname) -> str:
Expand Down
35 changes: 34 additions & 1 deletion src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import os
import decimal
from urllib.parse import urlparse
from uuid import UUID

from databricks.sql import __version__
Expand Down Expand Up @@ -322,6 +323,20 @@ def read(self) -> Optional[OAuthToken]:
session_id_hex=self.get_session_id_hex()
)

# Determine proxy usage
use_proxy = self.http_client.using_proxy()
proxy_host_info = None
if (
use_proxy
and self.http_client.proxy_uri
and isinstance(self.http_client.proxy_uri, str)
):
parsed = urlparse(self.http_client.proxy_uri)
proxy_host_info = HostDetails(
host_url=parsed.hostname or self.http_client.proxy_uri,
port=parsed.port or 8080,
)

driver_connection_params = DriverConnectionParameters(
http_path=http_path,
mode=DatabricksClientType.SEA
Expand All @@ -331,13 +346,31 @@ def read(self) -> Optional[OAuthToken]:
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
socket_timeout=kwargs.get("_socket_timeout", None),
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id", None),
azure_tenant_id=kwargs.get("azure_tenant_id", None),
use_proxy=use_proxy,
use_system_proxy=use_proxy,
proxy_host_info=proxy_host_info,
use_cf_proxy=False, # CloudFlare proxy not yet supported in Python
cf_proxy_host_info=None, # CloudFlare proxy not yet supported in Python
non_proxy_hosts=None,
allow_self_signed_support=kwargs.get("_tls_no_verify", False),
use_system_trust_store=True, # Python uses system SSL by default
enable_arrow=pyarrow is not None,
enable_direct_results=True, # Always enabled in Python
enable_sea_hybrid_results=kwargs.get("use_hybrid_disposition", False),
http_connection_pool_size=kwargs.get("pool_maxsize", None),
rows_fetched_per_block=DEFAULT_ARRAY_SIZE,
async_poll_interval_millis=2000, # Default polling interval
support_many_parameters=True, # Native parameters supported
enable_complex_datatype_support=_use_arrow_native_complex_types,
allowed_volume_ingestion_paths=self.staging_allowed_local_path,
)

self._telemetry_client.export_initial_telemetry_log(
driver_connection_params=driver_connection_params,
user_agent=self.session.useragent_header,
)
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
"""Valid values are True, False, and "silent"
Expand Down
22 changes: 21 additions & 1 deletion src/databricks/sql/common/unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,22 @@ def request_context(
yield response
except MaxRetryError as e:
logger.error("HTTP request failed after retries: %s", e)
raise RequestError(f"HTTP request failed: {e}")

# Try to extract HTTP status code from the MaxRetryError
http_code = None
if hasattr(e, "reason") and hasattr(e.reason, "response"):
# The reason may contain a response object with status
http_code = getattr(e.reason.response, "status", None)
elif hasattr(e, "response") and hasattr(e.response, "status"):
# Or the error itself may have a response
http_code = e.response.status

context = {}
if http_code is not None:
context["http-code"] = http_code
logger.error("HTTP request failed with status code: %d", http_code)

raise RequestError(f"HTTP request failed: {e}", context=context)
except Exception as e:
logger.error("HTTP request error: %s", e)
raise RequestError(f"HTTP request error: {e}")
Expand Down Expand Up @@ -301,6 +316,11 @@ def using_proxy(self) -> bool:
"""Check if proxy support is available (not whether it's being used for a specific request)."""
return self._proxy_pool_manager is not None

@property
def proxy_uri(self) -> Optional[str]:
"""Get the configured proxy URI, if any."""
return self._proxy_uri

def close(self):
"""Close the underlying connection pools."""
if self._direct_pool_manager:
Expand Down
7 changes: 7 additions & 0 deletions src/databricks/sql/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,10 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class TelemetryRateLimitError(Exception):
"""Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable.
This exception is used exclusively by the circuit breaker to track telemetry rate limiting events."""

pass
194 changes: 194 additions & 0 deletions src/databricks/sql/telemetry/circuit_breaker_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Circuit breaker implementation for telemetry requests.

This module provides circuit breaker functionality to prevent telemetry failures
from impacting the main SQL operations. It uses pybreaker library to implement
the circuit breaker pattern with configurable thresholds and timeouts.
"""

import logging
import threading
from typing import Dict, Optional, Any
from dataclasses import dataclass

import pybreaker
from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener

from databricks.sql.exc import TelemetryRateLimitError

logger = logging.getLogger(__name__)

# Circuit Breaker Configuration Constants
DEFAULT_MINIMUM_CALLS = 20
DEFAULT_RESET_TIMEOUT = 30
DEFAULT_NAME = "telemetry-circuit-breaker"

# Circuit Breaker State Constants (used in logging)
CIRCUIT_BREAKER_STATE_OPEN = "open"
CIRCUIT_BREAKER_STATE_CLOSED = "closed"
CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open"

# Logging Message Constants
LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s"
LOG_CIRCUIT_BREAKER_OPENED = (
"Circuit breaker opened for %s - telemetry requests will be blocked"
)
LOG_CIRCUIT_BREAKER_CLOSED = (
"Circuit breaker closed for %s - telemetry requests will be allowed"
)
LOG_CIRCUIT_BREAKER_HALF_OPEN = (
"Circuit breaker half-open for %s - testing telemetry requests"
)


class CircuitBreakerStateListener(CircuitBreakerListener):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only used for logging purposed for now

"""Listener for circuit breaker state changes."""

def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None:
"""Called before the circuit breaker calls a function."""
pass

def failure(self, cb: CircuitBreaker, exc: BaseException) -> None:
"""Called when a function called by the circuit breaker fails."""
pass

def success(self, cb: CircuitBreaker) -> None:
"""Called when a function called by the circuit breaker succeeds."""
pass

def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None:
"""Called when the circuit breaker state changes."""
old_state_name = old_state.name if old_state else "None"
new_state_name = new_state.name if new_state else "None"

logger.info(
LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name
)

if new_state_name == CIRCUIT_BREAKER_STATE_OPEN:
logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name)
elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED:
logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name)
elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN:
logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name)


@dataclass(frozen=True)
class CircuitBreakerConfig:
"""Configuration for circuit breaker behavior.

This class is immutable to prevent modification of circuit breaker settings.
All configuration values are set to constants defined at the module level.
"""

# Minimum number of calls before circuit can open
minimum_calls: int = DEFAULT_MINIMUM_CALLS

# Time to wait before trying to close circuit (in seconds)
reset_timeout: int = DEFAULT_RESET_TIMEOUT

# Name for the circuit breaker (for logging)
name: str = DEFAULT_NAME


class CircuitBreakerManager:
"""
Manages circuit breaker instances for telemetry requests.

This class provides a singleton pattern to manage circuit breaker instances
per host, ensuring that telemetry failures don't impact main SQL operations.
"""

_instances: Dict[str, CircuitBreaker] = {}
_lock = threading.RLock()
_config: Optional[CircuitBreakerConfig] = None

@classmethod
def initialize(cls, config: CircuitBreakerConfig) -> None:
"""
Initialize the circuit breaker manager with configuration.

Args:
config: Circuit breaker configuration
"""
with cls._lock:
cls._config = config
logger.debug("CircuitBreakerManager initialized with config: %s", config)

@classmethod
def get_circuit_breaker(cls, host: str) -> CircuitBreaker:
"""
Get or create a circuit breaker instance for the specified host.

Args:
host: The hostname for which to get the circuit breaker

Returns:
CircuitBreaker instance for the host
"""
if not cls._config:
# Return a no-op circuit breaker if not initialized
return cls._create_noop_circuit_breaker()

with cls._lock:
if host not in cls._instances:
cls._instances[host] = cls._create_circuit_breaker(host)
logger.debug("Created circuit breaker for host: %s", host)

return cls._instances[host]

@classmethod
def _create_circuit_breaker(cls, host: str) -> CircuitBreaker:
"""
Create a new circuit breaker instance for the specified host.

Args:
host: The hostname for the circuit breaker

Returns:
New CircuitBreaker instance
"""
config = cls._config
if config is None:
raise RuntimeError("CircuitBreakerManager not initialized")

# Create circuit breaker with configuration
breaker = CircuitBreaker(
fail_max=config.minimum_calls, # Number of failures before circuit opens
reset_timeout=config.reset_timeout,
name=f"{config.name}-{host}",
)

# Add state change listeners for logging
breaker.add_listener(CircuitBreakerStateListener())

return breaker

@classmethod
def _create_noop_circuit_breaker(cls) -> CircuitBreaker:
"""
Create a no-op circuit breaker that always allows calls.

Returns:
CircuitBreaker that never opens
"""
# Create a circuit breaker with very high thresholds so it never opens
breaker = CircuitBreaker(
fail_max=1000000, # Very high threshold
reset_timeout=1, # Short reset time
name="noop-circuit-breaker",
)
return breaker


def is_circuit_breaker_error(exception: Exception) -> bool:
"""
Check if an exception is a circuit breaker error.

Args:
exception: The exception to check

Returns:
True if the exception is a circuit breaker error
"""
return isinstance(exception, CircuitBreakerError)
Loading
Loading