Skip to content

Websockets #1

@davidpdrsn

Description

@davidpdrsn

Something along the lines of the following sketch:

#![allow(unused_imports)]

use bytes::Bytes;
use futures::prelude::*;
use futures::SinkExt;
use http::{header::HeaderName, HeaderValue, Request, Response, StatusCode};
use http_body::Empty;
use hyper::{
    server::conn::AddrStream,
    upgrade::{OnUpgrade, Upgraded},
    Body,
};
use sha1::{Digest, Sha1};
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::{convert::Infallible, task::Poll};
use tokio_tungstenite::{
    tungstenite::protocol::{self, WebSocketConfig},
    WebSocketStream,
};
use tower::{make::Shared, ServiceBuilder};
use tower::{BoxError, Service};
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tower_http::LatencyUnit;

/// Serves the WebSocket echo-logger on 0.0.0.0:3000, with HTTP request tracing.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Every request passes through the HTTP trace layer before reaching the upgrade service.
    let websocket_service = WebSocketUpgrade::new(handle_socket);
    let traced_service = ServiceBuilder::new()
        .layer(TraceLayer::new_for_http())
        .service(websocket_service);

    let listen_addr = std::net::SocketAddr::from(([0, 0, 0, 0], 3000));

    let server = hyper::Server::bind(&listen_addr).serve(Shared::new(traced_service));
    server.await.unwrap();
}

/// Example socket handler: prints every item received (message or error)
/// until `recv` reports the stream is finished.
async fn handle_socket(mut socket: WebSocket) {
    loop {
        match socket.recv().await {
            Some(msg) => println!("received message: {:?}", msg),
            None => break,
        }
    }
}

/// A service that performs the WebSocket handshake and hands each upgraded
/// connection to `callback`.
///
/// `config` is passed to tungstenite when the upgraded stream is wrapped.
#[derive(Debug, Clone)]
pub struct WebSocketUpgrade<F> {
    callback: F,
    config: WebSocketConfig,
}

impl<F> WebSocketUpgrade<F> {
    /// Creates an upgrade service using tungstenite's default [`WebSocketConfig`].
    pub fn new(callback: F) -> Self {
        let config = WebSocketConfig::default();
        Self { callback, config }
    }
}

impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
where
    F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    type Response = Response<Empty<Bytes>>;
    type Error = BoxError;
    type Future = ResponseFuture;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        // TODO(david): missing `upgrade` should return "bad request"

        if !header_eq(
            &req,
            HeaderName::from_static("upgrade"),
            HeaderValue::from_static("websocket"),
        ) {
            todo!()
        }

        if !header_eq(
            &req,
            HeaderName::from_static("sec-websocket-version"),
            HeaderValue::from_static("13"),
        ) {
            todo!()
        }

        let key = if let Some(key) = req.headers_mut().remove("sec-websocket-key") {
            key
        } else {
            todo!()
        };

        let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();

        let config = self.config;
        let callback = self.callback.clone();

        tokio::spawn(async move {
            let upgraded = on_upgrade.await.unwrap();
            let socket =
                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
                    .await;
            let socket = WebSocket { inner: socket };
            callback(socket).await;
        });

        ResponseFuture { key: Some(key) }
    }
}

#[derive(Debug)]
pub struct ResponseFuture {
    key: Option<HeaderValue>,
}

impl Future for ResponseFuture {
    type Output = Result<Response<Empty<Bytes>>, BoxError>;

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let res = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(
                http::header::CONNECTION,
                HeaderValue::from_str("upgrade").unwrap(),
            )
            .header(
                http::header::UPGRADE,
                HeaderValue::from_str("websocket").unwrap(),
            )
            .header(
                http::header::SEC_WEBSOCKET_ACCEPT,
                sign(self.as_mut().key.take().unwrap().as_bytes()),
            )
            .body(Empty::new())
            .unwrap();

        Poll::Ready(Ok(res))
    }
}

/// Returns `true` when `req` has header `key` and its first value is exactly
/// equal to `value` (case-sensitive comparison).
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
    req.headers().get(&key) == Some(&value)
}

/// Derives the `Sec-WebSocket-Accept` value from a client's `Sec-WebSocket-Key`.
///
/// The value is `base64(SHA-1(key ++ GUID))`, where GUID is the fixed string below.
///
/// Adapted from https://github.com/hyperium/headers/blob/master/src/common/sec_websocket_accept.rs#L38
fn sign(key: &[u8]) -> HeaderValue {
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let digest = {
        let mut hasher = Sha1::default();
        hasher.update(key);
        hasher.update(WEBSOCKET_GUID);
        hasher.finalize()
    };
    let encoded = Bytes::from(base64::encode(&digest));
    HeaderValue::from_maybe_shared(encoded).expect("base64 is a valid value")
}

/// A WebSocket connection running over a hyper-upgraded stream.
#[derive(Debug)]
pub struct WebSocket {
    inner: WebSocketStream<Upgraded>,
}

impl WebSocket {
    /// Waits for the next message from the peer.
    ///
    /// Returns `None` once the underlying stream ends. Errors from the stream
    /// are boxed into a [`BoxError`].
    pub async fn recv(&mut self) -> Option<Result<protocol::Message, BoxError>> {
        let next = self.inner.next().await?;
        Some(next.map_err(Into::into))
    }
}

// TODO(david): impl Stream<Message>
// TODO(david): WebSocket::close

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions