-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Closed
Description
Something along the lines of the following:
#![allow(unused_imports)]
use bytes::Bytes;
use futures::prelude::*;
use futures::SinkExt;
use http::{header::HeaderName, HeaderValue, Request, Response, StatusCode};
use http_body::Empty;
use hyper::{
server::conn::AddrStream,
upgrade::{OnUpgrade, Upgraded},
Body,
};
use sha1::{Digest, Sha1};
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::{convert::Infallible, task::Poll};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};
use tower::{make::Shared, ServiceBuilder};
use tower::{BoxError, Service};
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tower_http::LatencyUnit;
/// Entry point: serves the WebSocket upgrade service, wrapped in HTTP tracing,
/// on `0.0.0.0:3000`.
///
/// # Panics
///
/// Panics if the server terminates with an error.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Every successfully upgraded connection is handed to `handle_socket`.
    let upgrade = WebSocketUpgrade::new(handle_socket);
    let svc = ServiceBuilder::new()
        .layer(TraceLayer::new_for_http())
        .service(upgrade);

    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], 3000));
    let server = hyper::Server::bind(&addr).serve(Shared::new(svc));
    server.await.unwrap();
}
/// Prints every item received on `socket` (messages and errors alike) until
/// the stream ends.
async fn handle_socket(mut socket: WebSocket) {
    loop {
        match socket.recv().await {
            Some(msg) => println!("received message: {:?}", msg),
            // `None` means the underlying stream is exhausted.
            None => break,
        }
    }
}
/// A service that upgrades incoming HTTP requests to WebSocket connections and
/// hands each resulting socket to a user-supplied callback.
#[derive(Debug, Clone)]
pub struct WebSocketUpgrade<F> {
// Invoked with the `WebSocket` of every upgraded connection; cloned per request.
callback: F,
// Protocol configuration passed to the `WebSocketStream` built for each connection.
config: WebSocketConfig,
}
impl<F> WebSocketUpgrade<F> {
    /// Creates an upgrade service that passes each upgraded socket to
    /// `callback`, using the default [`WebSocketConfig`].
    pub fn new(callback: F) -> Self {
        let config = WebSocketConfig::default();
        Self { callback, config }
    }
}
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
where
    F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    type Response = Response<Empty<Bytes>>;
    type Error = BoxError;
    type Future = ResponseFuture;

    /// Always ready: the service holds no per-request resources.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Validates the WebSocket opening handshake and, when it is acceptable,
    /// spawns a task that waits for the connection upgrade and runs the
    /// callback on the resulting [`WebSocket`].
    ///
    /// The returned future resolves to:
    /// - `101 Switching Protocols` for a valid handshake,
    /// - `426 Upgrade Required` (advertising version 13) when
    ///   `sec-websocket-version` is missing or not `13`,
    /// - `400 Bad Request` for any other malformed handshake.
    ///
    /// Malformed requests never panic and never spawn a task.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        if !header_eq(
            &req,
            HeaderName::from_static("upgrade"),
            HeaderValue::from_static("websocket"),
        ) {
            return ResponseFuture::reject(StatusCode::BAD_REQUEST);
        }

        if !header_eq(
            &req,
            HeaderName::from_static("sec-websocket-version"),
            HeaderValue::from_static("13"),
        ) {
            return ResponseFuture::unsupported_version();
        }

        let key = match req.headers_mut().remove("sec-websocket-key") {
            Some(key) => key,
            None => return ResponseFuture::reject(StatusCode::BAD_REQUEST),
        };

        // NOTE(review): hyper appears to attach `OnUpgrade` only to requests it
        // considers upgradable, so its absence is treated as a bad handshake
        // rather than a server bug — confirm against the hyper version in use.
        let on_upgrade = match req.extensions_mut().remove::<OnUpgrade>() {
            Some(on_upgrade) => on_upgrade,
            None => return ResponseFuture::reject(StatusCode::BAD_REQUEST),
        };

        let config = self.config;
        let callback = self.callback.clone();

        tokio::spawn(async move {
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(err) => {
                    // The client may disconnect before the upgrade completes;
                    // that must not panic the task.
                    eprintln!("websocket upgrade failed: {}", err);
                    return;
                }
            };

            let inner =
                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
                    .await;
            callback(WebSocket { inner }).await;
        });

        ResponseFuture::accept(key)
    }
}

/// Future returned by [`WebSocketUpgrade`]; resolves immediately to the
/// handshake response (either the `101` acceptance or an error status).
#[derive(Debug)]
pub struct ResponseFuture {
    // `Some` until the first poll hands the response out.
    response: Option<Response<Empty<Bytes>>>,
}

impl ResponseFuture {
    /// Wraps an already-built response.
    fn ready(response: Response<Empty<Bytes>>) -> Self {
        Self {
            response: Some(response),
        }
    }

    /// Builds the `101 Switching Protocols` response for the client's
    /// `sec-websocket-key`.
    fn accept(key: HeaderValue) -> Self {
        let mut res: Response<Empty<Bytes>> = Response::new(Empty::new());
        *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;

        let headers = res.headers_mut();
        headers.insert(http::header::CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(http::header::UPGRADE, HeaderValue::from_static("websocket"));
        headers.insert(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));

        Self::ready(res)
    }

    /// Builds an empty-bodied rejection with the given status.
    fn reject(status: StatusCode) -> Self {
        let mut res: Response<Empty<Bytes>> = Response::new(Empty::new());
        *res.status_mut() = status;
        Self::ready(res)
    }

    /// Builds a `426 Upgrade Required` rejection that tells the client which
    /// protocol version this server speaks.
    fn unsupported_version() -> Self {
        let mut res: Response<Empty<Bytes>> = Response::new(Empty::new());
        *res.status_mut() = StatusCode::UPGRADE_REQUIRED;
        res.headers_mut().insert(
            http::header::SEC_WEBSOCKET_VERSION,
            HeaderValue::from_static("13"),
        );
        Self::ready(res)
    }
}

impl Future for ResponseFuture {
    type Output = Result<Response<Empty<Bytes>>, BoxError>;

    /// Yields the prepared response on the first poll.
    ///
    /// # Panics
    ///
    /// Panics if polled again after completion.
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let res = self
            .as_mut()
            .response
            .take()
            .expect("ResponseFuture polled after completion");
        Poll::Ready(Ok(res))
    }
}
/// Returns `true` when `req` carries header `key` and its value equals
/// `value`, compared ASCII case-insensitively.
///
/// Handshake tokens such as `Upgrade: websocket` are case-insensitive, so
/// `WebSocket` must match too. Returns `false` when the header is absent.
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
    req.headers().get(&key).map_or(false, |header| {
        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    })
}
// from https://github.com/hyperium/headers/blob/master/src/common/sec_websocket_accept.rs#L38
/// Computes the `Sec-WebSocket-Accept` value for a client key: the base64 of
/// the SHA-1 digest of the key followed by the fixed WebSocket GUID.
fn sign(key: &[u8]) -> HeaderValue {
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let digest = hasher.finalize();

    let encoded = Bytes::from(base64::encode(&digest));
    HeaderValue::from_maybe_shared(encoded).expect("base64 is a valid value")
}
/// A WebSocket connection established over an upgraded hyper connection;
/// this is the value handed to the `WebSocketUpgrade` callback.
#[derive(Debug)]
pub struct WebSocket {
// Underlying tungstenite stream that `recv` reads from.
inner: WebSocketStream<Upgraded>,
}
impl WebSocket {
    /// Waits for the next item from the peer.
    ///
    /// Returns `None` once the underlying stream has ended, `Some(Ok(_))` for
    /// a received message, and `Some(Err(_))` with the stream error boxed
    /// otherwise.
    pub async fn recv(&mut self) -> Option<Result<protocol::Message, BoxError>> {
        match self.inner.next().await {
            Some(Ok(message)) => Some(Ok(message)),
            Some(Err(err)) => Some(Err(err.into())),
            None => None,
        }
    }
}
// TODO(david): impl Stream<Message>
// TODO(david): WebSocket::close
Metadata
Metadata
Assignees
Labels
No labels