Skip to content

Commit f3143e2

Browse files
authored
Fix thread-safety issue while configuring SSLContext
1 parent 079f0eb commit f3143e2

File tree

5 files changed

+73
-4
lines changed

5 files changed

+73
-4
lines changed

noxfile.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34
import nox
45

@@ -15,6 +16,11 @@ def test(session):
1516
session.env["YARL_NO_EXTENSIONS"] = "1"
1617
session.env["FROZENLIST_NO_EXTENSIONS"] = "1"
1718

19+
# This would need to be updated if we added PyPy
20+
# to our default Python versions above. Right now
21+
# PyPy is only used on CI.
22+
pypy = session.python is None and sys.implementation.name == "pypy"
23+
1824
session.install("-rdev-requirements.txt", ".")
1925
session.run("pip", "freeze")
2026
session.run(
@@ -23,6 +29,7 @@ def test(session):
2329
"-s",
2430
"-rs",
2531
"--no-flaky-report",
32+
*(("--config-file=pypy-pytest.ini",) if pypy else ()),
2633
"--max-runs=3",
2734
*(session.posargs or ("tests/",)),
2835
)

pypy-pytest.ini

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# This pytest.ini is only for PyPy
2+
# to suppress 'ResourceWarning' from aiohttp.
3+
4+
[pytest]
5+
asyncio_mode = strict
6+
asyncio_default_fixture_loop_scope = function
7+
filterwarnings =
8+
error
9+
# See: aio-libs/aiohttp#7545
10+
ignore:.*datetime.utcfromtimestamp().*:DeprecationWarning
11+
# PyPy and aiohttp don't play nice.
12+
ignore:.*:ResourceWarning
13+
markers =
14+
internet: test requires Internet access

src/truststore/_api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import contextlib
12
import os
23
import platform
34
import socket
45
import ssl
56
import sys
7+
import threading
68
import typing
79

810
import _ssl
@@ -84,6 +86,7 @@ def __class__(self) -> type:
8486

8587
def __init__(self, protocol: int = None) -> None: # type: ignore[assignment]
8688
self._ctx = _original_SSLContext(protocol)
89+
self._ctx_lock = threading.Lock()
8790

8891
class TruststoreSSLObject(ssl.SSLObject):
8992
# This object exists because wrap_bio() doesn't
@@ -106,10 +109,15 @@ def wrap_socket(
106109
server_hostname: str | None = None,
107110
session: ssl.SSLSession | None = None,
108111
) -> ssl.SSLSocket:
109-
# Use a context manager here because the
110-
# inner SSLContext holds on to our state
111-
# but also does the actual handshake.
112-
with _configure_context(self._ctx):
112+
113+
# We need to lock around the .__enter__()
114+
# but we don't need to lock within the
115+
# context manager, so we need to expand the
116+
# syntactic sugar of the `with` statement.
117+
with contextlib.ExitStack() as stack:
118+
with self._ctx_lock:
119+
stack.enter_context(_configure_context(self._ctx))
120+
113121
ssl_sock = self._ctx.wrap_socket(
114122
sock,
115123
server_side=server_side,

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,6 @@ async def handler(request: web.Request) -> web.Response:
172172
yield Server(host="localhost", port=port)
173173
finally:
174174
await site.stop()
175+
# Wait for the server to actually close.
176+
await site._server.wait_closed()
175177
await runner.cleanup()

tests/test_threading.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import asyncio
2+
import socket
3+
import ssl
4+
import threading
5+
6+
import pytest
7+
8+
import truststore
9+
10+
11+
def wrap_and_close_sockets(ctx: truststore.SSLContext, host: str, port: int) -> None:
12+
for _ in range(100):
13+
sock = None
14+
try:
15+
sock = socket.create_connection((host, port))
16+
sock = ctx.wrap_socket(sock, server_hostname=host)
17+
finally:
18+
if sock:
19+
sock.close()
20+
21+
22+
@pytest.mark.asyncio
23+
async def test_threading(server):
24+
def run_threads():
25+
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
26+
threads = [
27+
threading.Thread(
28+
target=wrap_and_close_sockets, args=(ctx, server.host, server.port)
29+
)
30+
for _ in range(16)
31+
]
32+
for t in threads:
33+
t.start()
34+
for t in threads:
35+
t.join()
36+
37+
thread = asyncio.to_thread(run_threads)
38+
await thread

0 commit comments

Comments
 (0)