diff --git a/.tool-versions b/.tool-versions index ee114d2..35b639f 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1 +1 @@ -rust 1.85.1 +rust 1.88.0 diff --git a/Cargo.lock b/Cargo.lock index 9d6c50d..a598332 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1201,6 +1201,8 @@ version = "0.1.0" dependencies = [ "futures", "spawned-rt", + "tokio", + "tokio-stream", "tracing", ] @@ -1210,6 +1212,7 @@ version = "0.1.0" dependencies = [ "crossbeam", "tokio", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", @@ -1342,6 +1345,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.7.15" diff --git a/concurrency/Cargo.toml b/concurrency/Cargo.toml index 57c225d..eae75a3 100644 --- a/concurrency/Cargo.toml +++ b/concurrency/Cargo.toml @@ -8,5 +8,10 @@ spawned-rt = { workspace = true } tracing = { workspace = true } futures = "0.3.1" +[dev-dependencies] +# This tokio imports are only used in tests, we should not use them in the library code. +tokio-stream = { version = "0.1.17" } +tokio = { version = "1", features = ["full"] } + [lib] -path = "./src/lib.rs" \ No newline at end of file +path = "./src/lib.rs" diff --git a/concurrency/src/tasks/mod.rs b/concurrency/src/tasks/mod.rs index 54f35fa..201c7d2 100644 --- a/concurrency/src/tasks/mod.rs +++ b/concurrency/src/tasks/mod.rs @@ -3,11 +3,15 @@ mod gen_server; mod process; +mod stream; mod time; +#[cfg(test)] +mod stream_tests; #[cfg(test)] mod timer_tests; pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg}; pub use process::{send, Process, ProcessInfo}; +pub use stream::spawn_listener; pub use time::{send_after, send_interval}; diff --git a/concurrency/src/tasks/stream.rs b/concurrency/src/tasks/stream.rs new file mode 100644 index 0000000..bb3df3b --- /dev/null +++ b/concurrency/src/tasks/stream.rs @@ -0,0 +1,57 @@ +use crate::tasks::{GenServer, GenServerHandle}; +use futures::{future::select, Stream, StreamExt}; +use spawned_rt::tasks::{CancellationToken, JoinHandle}; + +/// Spawns a listener that listens to a stream and sends messages to a GenServer. +/// +/// Items sent through the stream are required to be wrapped in a Result type. +/// +/// This function returns a handle to the spawned task and a cancellation token +/// to stop it. +pub fn spawn_listener( + mut handle: GenServerHandle, + message_builder: F, + mut stream: S, +) -> (JoinHandle<()>, CancellationToken) +where + T: GenServer + 'static, + F: Fn(I) -> T::CastMsg + Send + 'static + std::marker::Sync, + I: Send, + E: std::fmt::Debug + Send, + S: Unpin + Send + Stream> + 'static, +{ + let cancelation_token = CancellationToken::new(); + let cloned_token = cancelation_token.clone(); + let join_handle = spawned_rt::tasks::spawn(async move { + let result = select( + Box::pin(cloned_token.cancelled()), + Box::pin(async { + loop { + match stream.next().await { + Some(Ok(i)) => match handle.cast(message_builder(i)).await { + Ok(_) => tracing::trace!("Message sent successfully"), + Err(e) => { + tracing::error!("Failed to send message: {e:?}"); + break; + } + }, + Some(Err(e)) => { + tracing::trace!("Received Error in msg {e:?}"); + break; + } + None => { + tracing::trace!("Stream finished"); + break; + } + } + } + }), + ) + .await; + match result { + futures::future::Either::Left(_) => tracing::trace!("Listener cancelled"), + futures::future::Either::Right(_) => (), // Stream finished or errored out + } + }); + (join_handle, cancelation_token) +} diff --git a/concurrency/src/tasks/stream_tests.rs b/concurrency/src/tasks/stream_tests.rs new file mode 100644 index 0000000..5236363 --- /dev/null +++ b/concurrency/src/tasks/stream_tests.rs @@ -0,0 +1,174 @@ +use std::time::Duration; + +use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream}; + +use crate::tasks::{ + stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle, +}; + +type SummatoryHandle = GenServerHandle; + +struct Summatory; + +type SummatoryState = u16; + +#[derive(Clone)] +struct UpdateSumatory { + added_value: u16, +} + +impl Summatory { + pub async fn get_value(server: &mut SummatoryHandle) -> Result { + server.call(()).await.map_err(|_| ()) + } +} + +impl GenServer for Summatory { + type CallMsg = (); // We only handle one type of call, so there is no need for a specific message type. + type CastMsg = UpdateSumatory; + type OutMsg = SummatoryState; + type State = SummatoryState; + type Error = (); + + fn new() -> Self { + Self + } + + async fn handle_cast( + &mut self, + message: Self::CastMsg, + _handle: &GenServerHandle, + state: Self::State, + ) -> CastResponse { + let new_state = state + message.added_value; + CastResponse::NoReply(new_state) + } + + async fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &SummatoryHandle, + state: Self::State, + ) -> CallResponse { + let current_value = state; + CallResponse::Reply(state, current_value) + } +} + +// In this example, the stream sends u8 values, which are converted to the type +// supported by the GenServer (UpdateSumatory / u16). +fn message_builder(value: u8) -> UpdateSumatory { + UpdateSumatory { + added_value: value as u16, + } +} + +#[test] +pub fn test_sum_numbers_from_stream() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut summatory_handle = Summatory::start(0); + let stream = tokio_stream::iter(vec![1u8, 2, 3, 4, 5].into_iter().map(Ok::)); + + spawn_listener(summatory_handle.clone(), message_builder, stream); + + // Wait for 1 second so the whole stream is processed + rt::sleep(Duration::from_secs(1)).await; + + let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + assert_eq!(val, 15); + }) +} + +#[test] +pub fn test_sum_numbers_from_channel() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut summatory_handle = Summatory::start(0); + let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); + + // Spawn a task to send numbers to the channel + spawned_rt::tasks::spawn(async move { + for i in 1..=5 { + tx.send(Ok(i)).unwrap(); + } + }); + + spawn_listener( + summatory_handle.clone(), + message_builder, + ReceiverStream::new(rx), + ); + + // Wait for 1 second so the whole stream is processed + rt::sleep(Duration::from_secs(1)).await; + + let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + assert_eq!(val, 15); + }) +} + +#[test] +pub fn test_sum_numbers_from_broadcast_channel() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut summatory_handle = Summatory::start(0); + let (tx, rx) = tokio::sync::broadcast::channel::(5); + + // Spawn a task to send numbers to the channel + spawned_rt::tasks::spawn(async move { + for i in 1u8..=5 { + tx.send(i).unwrap(); + } + }); + + spawn_listener( + summatory_handle.clone(), + message_builder, + BroadcastStream::new(rx), + ); + + // Wait for 1 second so the whole stream is processed + rt::sleep(Duration::from_secs(1)).await; + + let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + assert_eq!(val, 15); + }) +} + +#[test] +pub fn test_stream_cancellation() { + const RUNNING_TIME: u64 = 1000; + + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut summatory_handle = Summatory::start(0); + let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); + + // Spawn a task to send numbers to the channel + spawned_rt::tasks::spawn(async move { + for i in 1..=5 { + tx.send(Ok(i)).unwrap(); + rt::sleep(Duration::from_millis(RUNNING_TIME / 4)).await; + } + }); + + let (_handle, cancellation_token) = spawn_listener( + summatory_handle.clone(), + message_builder, + ReceiverStream::new(rx), + ); + + // Wait for 1 second so the whole stream is processed + rt::sleep(Duration::from_millis(RUNNING_TIME)).await; + + cancellation_token.cancel(); + + // The reasoning for this assertion is that each message takes a quarter of the total time + // to be processed, so having a stream of 5 messages, the last one won't be processed. + // We could safely assume that it will get to process 4 messages, but in case of any extenal + // slowdown, it could process less. + let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); + assert!(val > 0 && val < 15); + }) +} diff --git a/concurrency/src/threads/mod.rs b/concurrency/src/threads/mod.rs index 44e9dcd..193af89 100644 --- a/concurrency/src/threads/mod.rs +++ b/concurrency/src/threads/mod.rs @@ -3,6 +3,7 @@ mod gen_server; mod process; +mod stream; mod time; #[cfg(test)] @@ -10,4 +11,5 @@ mod timer_tests; pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg}; pub use process::{send, Process, ProcessInfo}; +pub use stream::spawn_listener; pub use time::{send_after, send_interval}; diff --git a/concurrency/src/threads/stream.rs b/concurrency/src/threads/stream.rs new file mode 100644 index 0000000..a4fd749 --- /dev/null +++ b/concurrency/src/threads/stream.rs @@ -0,0 +1,17 @@ +use crate::threads::{GenServer, GenServerHandle}; + +use futures::Stream; + +/// Spawns a listener that listens to a stream and sends messages to a GenServer. +/// +/// Items sent through the stream are required to be wrapped in a Result type. +pub fn spawn_listener(_handle: GenServerHandle, _message_builder: F, _stream: S) +where + T: GenServer + 'static, + F: Fn(I) -> T::CastMsg + Send + 'static, + I: Send + 'static, + E: std::fmt::Debug + Send + 'static, + S: Unpin + Send + Stream> + 'static, +{ + unimplemented!("Unsupported function in threads mode") +} diff --git a/rt/Cargo.toml b/rt/Cargo.toml index b4fbbcf..b3317ac 100644 --- a/rt/Cargo.toml +++ b/rt/Cargo.toml @@ -6,9 +6,10 @@ edition = "2021" [dependencies] tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.7.15" } +tokio-stream = { version = "0.1.17", features = ["sync"] } crossbeam = { version = "0.7.3" } tracing = { workspace = true } tracing-subscriber = { workspace = true } [lib] -path = "./src/lib.rs" \ No newline at end of file +path = "./src/lib.rs" diff --git a/rt/src/tasks/mod.rs b/rt/src/tasks/mod.rs index 7508d43..10de5fd 100644 --- a/rt/src/tasks/mod.rs +++ b/rt/src/tasks/mod.rs @@ -18,6 +18,7 @@ pub use crate::tasks::tokio::oneshot; pub use crate::tasks::tokio::sleep; pub use crate::tasks::tokio::CancellationToken; pub use crate::tasks::tokio::{spawn, spawn_blocking, JoinHandle, Runtime}; +pub use crate::tasks::tokio::{BroadcastStream, ReceiverStream}; use std::future::Future; pub fn run(future: F) -> F::Output { diff --git a/rt/src/tasks/tokio/mod.rs b/rt/src/tasks/tokio/mod.rs index aaf679d..6abf60d 100644 --- a/rt/src/tasks/tokio/mod.rs +++ b/rt/src/tasks/tokio/mod.rs @@ -7,4 +7,5 @@ pub use tokio::{ task::{spawn, spawn_blocking, JoinHandle}, time::sleep, }; +pub use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream as ReceiverStream}; pub use tokio_util::sync::CancellationToken;