Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions irpc-iroh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ n0-future = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt"] }
irpc-derive = { version = "0.5.0", path = "../irpc-derive" }
clap = { version = "4.5.41", features = ["derive"] }
rand = "0.8"
144 changes: 91 additions & 53 deletions irpc-iroh/examples/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//! * Manually implementing the connection loop
//! * Authenticating peers

use std::time::Duration;

use anyhow::Result;
use iroh::{protocol::Router, Endpoint, Watcher};
use iroh::{protocol::Router, Endpoint, NodeAddr, SecretKey, Watcher};

use self::storage::{StorageClient, StorageServer};

Expand All @@ -17,20 +19,28 @@ async fn main() -> Result<()> {
}

async fn remote() -> Result<()> {
let (server_router, server_addr) = {
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
let server_secret_key = SecretKey::generate(&mut rand::rngs::OsRng);
let server_addr = NodeAddr::new(server_secret_key.public());
let start_server = async move || {
let endpoint = Endpoint::builder()
.secret_key(server_secret_key.clone())
.discovery_n0()
.bind()
.await?;
let server = StorageServer::new("secret".to_string());
let router = Router::builder(endpoint.clone())
.accept(StorageServer::ALPN, server.clone())
.spawn();
let addr = endpoint.node_addr().initialized().await;
(router, addr)
let _ = endpoint.home_relay().initialized().await;
// wait a bit for publishing to complete..
tokio::time::sleep(Duration::from_millis(500)).await;
anyhow::Ok(router)
};
let mut server_router = (start_server)().await?;

// correct authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
api.auth("secret").await?;
let client_endpoint = Endpoint::builder().discovery_n0().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "secret");
api.set("hello".to_string(), "world".to_string()).await?;
api.set("goodbye".to_string(), "world".to_string()).await?;
let value = api.get("hello".to_string()).await?;
Expand All @@ -40,15 +50,21 @@ async fn remote() -> Result<()> {
println!("list value = {value:?}");
}

// invalid authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
assert!(api.auth("bad").await.is_err());
assert!(api.get("hello".to_string()).await.is_err());
// restart server
server_router.shutdown().await?;
server_router = (start_server)().await?;

// reconnections work: client will transparently reauthenticate
println!("restarting server");
let value = api.get("hello".to_string()).await?;
println!("value = {value:?}");
api.set("hello".to_string(), "world".to_string()).await?;
let value = api.get("hello".to_string()).await?;
println!("value = {value:?}");

// no authentication
// invalid authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr);
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "bad");
assert!(api.get("hello".to_string()).await.is_err());

drop(server_router);
Expand All @@ -65,17 +81,16 @@ mod storage {
sync::{Arc, Mutex},
};

use anyhow::Result;
use anyhow::{anyhow, Result};
use iroh::{
endpoint::Connection,
protocol::{AcceptError, ProtocolHandler},
Endpoint,
};
use irpc::{
channel::{mpsc, oneshot},
rpc_requests, Client, WithChannels,
rpc_requests, Client, Request, RequestError, WithChannels,
};
// Import the macro
use irpc_iroh::{read_request, IrohRemoteConnection};
use serde::{Deserialize, Serialize};
use tracing::info;
Expand Down Expand Up @@ -109,7 +124,8 @@ mod storage {
#[rpc_requests(message = StorageMessage)]
#[derive(Serialize, Deserialize, Debug)]
enum StorageProtocol {
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
// Connection will be closed if auth fails.
#[rpc(tx=oneshot::Sender<()>)]
Auth(Auth),
#[rpc(tx=oneshot::Sender<Option<String>>)]
Get(Get),
Expand All @@ -129,31 +145,29 @@ mod storage {

impl ProtocolHandler for StorageServer {
async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
let mut authed = false;
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
match msg {
StorageMessage::Auth(msg) => {
let WithChannels { inner, tx, .. } = msg;
if authed {
conn.close(1u32.into(), b"invalid message");
break;
} else if inner.token != self.auth_token {
conn.close(1u32.into(), b"permission denied");
break;
} else {
authed = true;
tx.send(Ok(())).await.ok();
}
}
msg => {
if !authed {
conn.close(1u32.into(), b"permission denied");
break;
} else {
self.handle_authenticated(msg).await;
}
}
// read first message: must be auth!
let msg = read_request::<StorageProtocol>(&conn).await?;
let auth_ok = if let Some(StorageMessage::Auth(msg)) = msg {
let WithChannels { inner, tx, .. } = msg;
if inner.token == self.auth_token {
tx.send(()).await.ok();
true
} else {
false
}
} else {
false
};

// if not authenticated: close connection immediately.
if !auth_ok {
conn.close(1u32.into(), b"permission denied");
return Ok(());
}

// now the connection is authenticated and we can handle all subsequent requests.
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
self.handle_request(msg).await;
}
conn.closed().await;
Ok(())
Expand All @@ -170,7 +184,7 @@ mod storage {
}
}

async fn handle_authenticated(&self, msg: StorageMessage) {
async fn handle_request(&self, msg: StorageMessage) {
match msg {
StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"),
StorageMessage::Get(get) => {
Expand Down Expand Up @@ -218,39 +232,63 @@ mod storage {
}

pub struct StorageClient {
api_token: String,
inner: Client<StorageProtocol>,
}

impl StorageClient {
pub const ALPN: &[u8] = ALPN;

pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
pub fn connect(
endpoint: Endpoint,
addr: impl Into<iroh::NodeAddr>,
api_token: &str,
) -> StorageClient {
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
StorageClient {
api_token: api_token.to_string(),
inner: Client::boxed(conn),
}
}

pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
self.inner
async fn authenticated_request(&self) -> Result<Request<StorageProtocol>, irpc::Error> {
let request = self.inner.request().await?;

// if the connection is not new: no need to reauthenticate.
if !request.is_new_connection() {
return Ok(request);
}

// if this is a new connection: use this request to send an auth message.
request
.rpc(Auth {
token: token.to_string(),
token: self.api_token.clone(),
})
.await?
.map_err(|err| anyhow::anyhow!(err))
.await?;
// and create a new request for the actual call.
let request = self.inner.request().await?;
// if this *again* created a new connection, we error out.
if request.is_new_connection() {
Err(RequestError::Other(anyhow!("Connection is reconnecting too often")).into())
} else {
Ok(request)
}
}

pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
self.inner.rpc(Get { key }).await
self.authenticated_request().await?.rpc(Get { key }).await
}

pub async fn list(&self) -> Result<mpsc::Receiver<String>, irpc::Error> {
self.inner.server_streaming(List, 10).await
self.authenticated_request()
.await?
.server_streaming(List, 10)
.await
}

pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
let msg = Set { key, value };
self.inner.rpc(msg).await
self.authenticated_request().await?.rpc(msg).await
}
}
}
25 changes: 15 additions & 10 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use iroh::{
use irpc::{
channel::RecvError,
rpc::{
Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED,
MAX_MESSAGE_SIZE,
Handler, RemoteConnection, RemoteService, RemoteStreams,
ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, MAX_MESSAGE_SIZE,
},
util::AsyncReadVarintExt,
LocalSender, RequestError,
Expand Down Expand Up @@ -60,27 +60,32 @@ impl RemoteConnection for IrohRemoteConnection {
Box::new(self.clone())
}

fn open_bi(&self) -> BoxFuture<std::result::Result<(SendStream, RecvStream), RequestError>> {
fn open_bi(&self) -> BoxFuture<std::result::Result<RemoteStreams, RequestError>> {
let this = self.0.clone();
Box::pin(async move {
let mut guard = this.connection.lock().await;
let pair = match guard.as_mut() {
Some(conn) => {
// try to reuse the connection
match conn.open_bi().await {
Ok(pair) => pair,
Ok(pair) => RemoteStreams::with_reused(pair),
Err(_) => {
// try with a new connection, just once
*guard = None;
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
.await
.map_err(RequestError::Other)?
let pair =
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
.await
.map_err(RequestError::Other)?;
RemoteStreams::with_new(pair)
}
}
}
None => connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
.await
.map_err(RequestError::Other)?,
None => {
let pair = connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
.await
.map_err(RequestError::Other)?;
RemoteStreams::with_new(pair)
}
};
Ok(pair)
})
Expand Down
Loading
Loading