Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3749.internal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove websockets deprecation warning by using the asyncio websocket provider
2 changes: 0 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from websockets import (
WebSocketException,
)
from websockets.legacy.client import (
connect,
)

Expand Down
7 changes: 0 additions & 7 deletions web3/providers/async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@
)

if TYPE_CHECKING:
from websockets.legacy.client import (
WebSocketClientProtocol,
)

from web3 import ( # noqa: F401
AsyncWeb3,
WebSocketProvider,
Expand Down Expand Up @@ -174,9 +170,6 @@ async def disconnect(self) -> None:
"Persistent connection providers must implement this method"
)

# WebSocket typing
_ws: "WebSocketClientProtocol"

# IPC typing
_reader: Optional[asyncio.StreamReader]
_writer: Optional[asyncio.StreamWriter]
Expand Down
34 changes: 20 additions & 14 deletions web3/providers/legacy_websocket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import (
annotations,
)

import asyncio
import json
import logging
Expand All @@ -9,22 +13,15 @@
TracebackType,
)
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Tuple,
Type,
Union,
cast,
)

from eth_typing import (
URI,
)
from websockets.legacy.client import (
WebSocketClientProtocol,
connect,
)

from web3._utils.batching import (
sort_batch_response_by_response_ids,
Expand All @@ -43,6 +40,11 @@
RPCResponse,
)

if TYPE_CHECKING:
from websockets.legacy.client import (
WebSocketClientProtocol,
)

RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"}
DEFAULT_WEBSOCKET_TIMEOUT = 30

Expand All @@ -66,18 +68,22 @@ def get_default_endpoint() -> URI:

class PersistentWebSocket:
def __init__(self, endpoint_uri: URI, websocket_kwargs: Any) -> None:
self.ws: Optional[WebSocketClientProtocol] = None
self.ws: WebSocketClientProtocol | None = None
self.endpoint_uri = endpoint_uri
self.websocket_kwargs = websocket_kwargs

async def __aenter__(self) -> WebSocketClientProtocol:
if self.ws is None:
from websockets.legacy.client import (
connect,
)

self.ws = await connect(uri=self.endpoint_uri, **self.websocket_kwargs)
return self.ws

async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
Expand All @@ -95,8 +101,8 @@ class LegacyWebSocketProvider(JSONBaseProvider):

def __init__(
self,
endpoint_uri: Optional[Union[URI, str]] = None,
websocket_kwargs: Optional[Any] = None,
endpoint_uri: URI | str | None = None,
websocket_kwargs: Any | None = None,
websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -144,8 +150,8 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
return future.result()

def make_batch_request(
self, requests: List[Tuple[RPCEndpoint, Any]]
) -> List[RPCResponse]:
self, requests: list[tuple[RPCEndpoint, Any]]
) -> list[RPCResponse]:
self.logger.debug(
"Making batch request WebSocket. URI: %s, Methods: %s",
self.endpoint_uri,
Expand Down
36 changes: 23 additions & 13 deletions web3/providers/persistent/websocket.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import (
annotations,
)

import asyncio
import json
import logging
import os
from typing import (
Any,
Dict,
Optional,
Union,
)

from eth_typing import (
Expand All @@ -15,14 +16,21 @@
from toolz import (
merge,
)

# python3.8 supports up to version 13,
# which does not default to the asyncio implementation yet.
# For this reason connect and ClientConnection need to be imported
# from asyncio.client explicitly.
# When web3.py stops supporting python3.8,
# it'll be possible to use `from websockets import connect, ClientConnection`.
from websockets.asyncio.client import (
ClientConnection,
connect,
)
from websockets.exceptions import (
ConnectionClosedOK,
WebSocketException,
)
from websockets.legacy.client import (
WebSocketClientProtocol,
connect,
)

from web3.exceptions import (
PersistentConnectionClosedOK,
Expand Down Expand Up @@ -57,12 +65,14 @@ class WebSocketProvider(PersistentConnectionProvider):
logger = logging.getLogger("web3.providers.WebSocketProvider")
is_async: bool = True

_ws: ClientConnection

def __init__(
self,
endpoint_uri: Optional[Union[URI, str]] = None,
websocket_kwargs: Optional[Dict[str, Any]] = None,
endpoint_uri: URI | str | None = None,
websocket_kwargs: dict[str, Any] | None = None,
# uses binary frames by default
use_text_frames: Optional[bool] = False,
use_text_frames: bool | None = False,
# `PersistentConnectionProvider` kwargs can be passed through
**kwargs: Any,
) -> None:
Expand All @@ -72,7 +82,7 @@ def __init__(
)
super().__init__(**kwargs)
self.use_text_frames = use_text_frames
self._ws: Optional[WebSocketClientProtocol] = None
self._ws: ClientConnection | None = None

if not any(
self.endpoint_uri.startswith(prefix)
Expand Down Expand Up @@ -119,7 +129,7 @@ async def socket_send(self, request_data: bytes) -> None:
"Connection to websocket has not been initiated for the provider."
)

payload: Union[bytes, str] = request_data
payload: bytes | str = request_data
if self.use_text_frames:
payload = request_data.decode("utf-8")

Expand All @@ -136,7 +146,7 @@ async def _provider_specific_connect(self) -> None:

async def _provider_specific_disconnect(self) -> None:
# this should remain idempotent
if self._ws is not None and not self._ws.closed:
if self._ws is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need to figure out how to check for self._ws.closed, even if it's not that same method name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, I did check this and I found out that self._ws.close() is idempotent already, what I didn't realise though is that web3.py still supports websockets >=10.0,<13.0, meaning me importing directly from websockets.asyncio would break that compatibility.
So I would need to rework my PR to take that into account.

I also noticed this issue #3679 which plans to move websockets bottom pin to >=14 #3530, which would essentially remove this problem completely.
Should I just do that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could either go down the path of supporting all the variations (aka <13; >=13,<14; >=14), i.e.

import websockets
websockets_version = tuple(int(x) for x in websockets.__version__.split(".") if x.isdigit())

if websockets_version < (13, 0):
    from websockets.legacy.client import (
        WebSocketClientProtocol as ClientConnection,  # we are safe to steal the name as the scope of interface we use is the same
        connect,
    )
else:
    # python3.8 supports up to version 13,
    # which does not default to the asyncio implementation yet.
    # For this reason connect and ClientConnection need to be imported
    # from asyncio.client explicitly.
    # When web3.py stops supporting python3.8,
    # it'll be possible to use `from websockets import connect, ClientConnection`.
    from websockets.asyncio.client import (
        ClientConnection,
        connect,
    )

or simply drop support for at least version <13.

Given the amount of usage of the websockets library interface either solution is fine, not a big deal.

await self._ws.close()
self._ws = None

Expand Down