Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ required-features = ["derive"]
name = "storage"
required-features = ["rpc", "quinn_endpoint_setup"]

[[example]]
name = "nested"
required-features = ["rpc", "derive", "quinn_endpoint_setup"]

[workspace]
members = ["irpc-derive", "irpc-iroh"]

Expand Down
96 changes: 96 additions & 0 deletions examples/nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::collections::HashMap;

use irpc::{channel::oneshot, rpc_requests, Client};
use serde::{Deserialize, Serialize};

#[rpc_requests(message = TestMessage)]
#[derive(Debug, Serialize, Deserialize)]
enum TestProtocol {
#[rpc(tx = oneshot::Sender<()>)]
Put(PutRequest),
#[rpc(tx = oneshot::Sender<Option<String>>)]
Get(GetRequest),
#[rpc(nested = NestedMessage)]
Nested(NestedProtocol),
}

#[derive(Debug, Serialize, Deserialize)]
struct PutRequest {
key: String,
value: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct GetRequest {
key: String,
}

#[rpc_requests(message = NestedMessage)]
#[derive(Debug, Serialize, Deserialize)]
enum NestedProtocol {
#[rpc(tx = oneshot::Sender<()>)]
Put(PutRequest2),
}

#[derive(Debug, Serialize, Deserialize)]
struct PutRequest2 {
key: String,
value: u32,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let (tx, rx) = tokio::sync::mpsc::channel(10);
tokio::task::spawn(actor(rx));
let client: Client<TestProtocol> = Client::from(tx);
client
.rpc(PutRequest {
key: "foo".to_string(),
value: "bar".to_string(),
})
.await?;
let v = client
.rpc(GetRequest {
key: "foo".to_string(),
})
.await?;
println!("{v:?}");
assert_eq!(v.as_deref(), Some("bar"));
client
.map::<NestedProtocol>()
.rpc(PutRequest2 {
key: "foo".to_string(),
value: 22,
})
.await?;
let v = client
.rpc(GetRequest {
key: "foo".to_string(),
})
.await?;
println!("{v:?}");
assert_eq!(v.as_deref(), Some("22"));
Ok(())
}

async fn actor(mut rx: tokio::sync::mpsc::Receiver<TestMessage>) {
let mut store = HashMap::new();
while let Some(msg) = rx.recv().await {
match msg {
TestMessage::Put(msg) => {
store.insert(msg.inner.key, msg.inner.value);
msg.tx.send(()).await.ok();
}
TestMessage::Get(msg) => {
let res = store.get(&msg.key).cloned();
msg.tx.send(res).await.ok();
}
TestMessage::Nested(inner) => match inner {
NestedMessage::Put(msg) => {
store.insert(msg.inner.key, msg.inner.value.to_string());
msg.tx.send(()).await.ok();
}
},
}
}
}
93 changes: 86 additions & 7 deletions examples/storage.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//! This example does not use the `rpc_requests` macro and instead implements
//! everything manually.

use std::{
collections::BTreeMap,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
Expand All @@ -8,12 +11,14 @@ use irpc::{
channel::{mpsc, none::NoReceiver, oneshot},
rpc::{listen, RemoteService},
util::{make_client_endpoint, make_server_endpoint},
Channels, Client, Request, Service, WithChannels,
Channels, Client, MappedClient, Request, Service, WithChannels,
};
use n0_future::task::{self, AbortOnDropHandle};
use serde::{Deserialize, Serialize};
use tracing::info;

use self::shout_crate::*;

impl Service for StorageProtocol {
type Message = StorageMessage;
}
Expand Down Expand Up @@ -52,13 +57,15 @@ enum StorageProtocol {
Get(Get),
Set(Set),
List(List),
Shout(ShoutProtocol),
}

#[derive(derive_more::From)]
enum StorageMessage {
Get(WithChannels<Get, StorageProtocol>),
Set(WithChannels<Set, StorageProtocol>),
List(WithChannels<List, StorageProtocol>),
Shout(ShoutMessage),
}

impl RemoteService for StorageProtocol {
Expand All @@ -67,6 +74,64 @@ impl RemoteService for StorageProtocol {
StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Shout(msg) => msg.with_remote_channels(rx, tx).into(),
}
}
}

/// This is a protocol that could live in a different crate.
mod shout_crate {
use irpc::{
channel::{none::NoReceiver, oneshot},
rpc::RemoteService,
Channels, Service, WithChannels,
};
use serde::{Deserialize, Serialize};
use tracing::info;

#[derive(derive_more::From, Serialize, Deserialize, Debug)]
pub enum ShoutProtocol {
Shout(Shout),
}

impl Service for ShoutProtocol {
type Message = ShoutMessage;
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Shout {
pub key: String,
}

impl Channels<ShoutProtocol> for Shout {
type Rx = NoReceiver;
type Tx = oneshot::Sender<String>;
}

#[derive(derive_more::From)]
pub enum ShoutMessage {
Shout(WithChannels<Shout, ShoutProtocol>),
}

impl RemoteService for ShoutProtocol {
fn with_remote_channels(
self,
rx: quinn::RecvStream,
tx: quinn::SendStream,
) -> Self::Message {
match self {
ShoutProtocol::Shout(msg) => WithChannels::from((msg, tx, rx)).into(),
}
}
}

pub async fn handle_message(msg: ShoutMessage) {
match msg {
ShoutMessage::Shout(msg) => {
info!("shout.shout: {msg:?}");
let WithChannels { tx, inner, .. } = msg;
tx.send(inner.key.to_uppercase()).await.ok();
}
}
}
}
Expand All @@ -84,9 +149,9 @@ impl StorageActor {
state: BTreeMap::new(),
};
n0_future::task::spawn(actor.run());
StorageApi {
inner: Client::local(tx),
}
let inner = Client::local(tx);
let shout = inner.map().to_owned();
StorageApi { inner, shout }
}

async fn run(mut self) {
Expand Down Expand Up @@ -117,18 +182,22 @@ impl StorageActor {
}
}
}
// We delegate these messages to the handler in the other "crate".
StorageMessage::Shout(msg) => shout_crate::handle_message(msg).await,
}
}
}

struct StorageApi {
inner: Client<StorageProtocol>,
shout: MappedClient<'static, StorageProtocol, ShoutProtocol>,
}

impl StorageApi {
pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> anyhow::Result<StorageApi> {
Ok(StorageApi {
inner: Client::quinn(endpoint, addr),
})
let inner = Client::quinn(endpoint, addr);
let shout = inner.map().to_owned();
Ok(StorageApi { inner, shout })
}

pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result<AbortOnDropHandle<()>> {
Expand Down Expand Up @@ -185,6 +254,12 @@ impl StorageApi {
}
}
}

pub async fn shout(&self, key: String) -> anyhow::Result<String> {
let msg = Shout { key };
let res = self.shout.rpc(msg).await?;
Ok(res)
}
}

async fn local() -> anyhow::Result<()> {
Expand All @@ -198,6 +273,8 @@ async fn local() -> anyhow::Result<()> {
println!("list value = {value:?}");
}
println!("value = {value:?}");
let res = api.shout("hello".to_string()).await?;
println!("shout.shout = {res:?}");
Ok(())
}

Expand All @@ -222,6 +299,8 @@ async fn remote() -> anyhow::Result<()> {
while let Some(value) = list.recv().await? {
println!("list value = {value:?}");
}
let shout = api.shout("hello".to_string()).await?;
println!("shout.shout = {shout:?}");
drop(handle);
Ok(())
}
Expand Down
Loading