diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index 39ec0b5..6c15ce0 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -1,20 +1,22 @@ //! GenServer trait and structs to create an abstraction similar to Erlang gen_server. //! See examples/name_server for a usage example. use futures::future::FutureExt as _; -use spawned_rt::tasks::{self as rt, mpsc, oneshot}; +use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken}; use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe}; use crate::error::GenServerError; -#[derive(Debug)] pub struct GenServerHandle { pub tx: mpsc::Sender>, + /// Cancellation token to stop the GenServer + cancellation_token: CancellationToken, } impl Clone for GenServerHandle { fn clone(&self) -> Self { Self { tx: self.tx.clone(), + cancellation_token: self.cancellation_token.clone(), } } } @@ -22,7 +24,11 @@ impl Clone for GenServerHandle { impl GenServerHandle { pub(crate) fn new(initial_state: G::State) -> Self { let (tx, mut rx) = mpsc::channel::>(); - let handle = GenServerHandle { tx }; + let cancellation_token = CancellationToken::new(); + let handle = GenServerHandle { + tx, + cancellation_token, + }; let mut gen_server: G = GenServer::new(); let handle_clone = handle.clone(); // Ignore the JoinHandle for now. Maybe we'll use it in the future @@ -40,7 +46,11 @@ impl GenServerHandle { pub(crate) fn new_blocking(initial_state: G::State) -> Self { let (tx, mut rx) = mpsc::channel::>(); - let handle = GenServerHandle { tx }; + let cancellation_token = CancellationToken::new(); + let handle = GenServerHandle { + tx, + cancellation_token, + }; let mut gen_server: G = GenServer::new(); let handle_clone = handle.clone(); // Ignore the JoinHandle for now. Maybe we'll use it in the future @@ -79,6 +89,10 @@ impl GenServerHandle { .send(GenServerInMsg::Cast { message }) .map_err(|_error| GenServerError::Server) } + + pub fn cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() + } } pub enum GenServerInMsg { @@ -168,12 +182,16 @@ where async { loop { let (new_state, cont) = self.receive(handle, rx, state).await?; + state = new_state; if !cont { break; } - state = new_state; } tracing::trace!("Stopping GenServer"); + handle.cancellation_token().cancel(); + if let Err(err) = self.teardown(handle, state).await { + tracing::error!("Error during teardown: {err:?}"); + } Ok(()) } } @@ -269,6 +287,17 @@ where ) -> impl Future> + Send { async { CastResponse::Unused } } + + /// Teardown function. It's called after the stop message is received. + /// It can be overrided on implementations in case final steps are required, + /// like closing streams, stopping timers, etc. + fn teardown( + &mut self, + _handle: &GenServerHandle, + _state: Self::State, + ) -> impl Future> + Send { + async { Ok(()) } + } } #[cfg(test)] diff --git a/concurrency/src/tasks/stream.rs b/concurrency/src/tasks/stream.rs index bb3df3b..4c4e844 100644 --- a/concurrency/src/tasks/stream.rs +++ b/concurrency/src/tasks/stream.rs @@ -1,6 +1,6 @@ use crate::tasks::{GenServer, GenServerHandle}; use futures::{future::select, Stream, StreamExt}; -use spawned_rt::tasks::{CancellationToken, JoinHandle}; +use spawned_rt::tasks::JoinHandle; /// Spawns a listener that listens to a stream and sends messages to a GenServer. /// @@ -12,7 +12,7 @@ pub fn spawn_listener( mut handle: GenServerHandle, message_builder: F, mut stream: S, -) -> (JoinHandle<()>, CancellationToken) +) -> JoinHandle<()> where T: GenServer + 'static, F: Fn(I) -> T::CastMsg + Send + 'static + std::marker::Sync, @@ -20,11 +20,10 @@ where E: std::fmt::Debug + Send, S: Unpin + Send + Stream> + 'static, { - let cancelation_token = CancellationToken::new(); - let cloned_token = cancelation_token.clone(); + let cancelation_token = handle.cancellation_token(); let join_handle = spawned_rt::tasks::spawn(async move { let result = select( - Box::pin(cloned_token.cancelled()), + Box::pin(cancelation_token.cancelled()), Box::pin(async { loop { match stream.next().await { @@ -49,9 +48,9 @@ where ) .await; match result { - futures::future::Either::Left(_) => tracing::trace!("Listener cancelled"), + futures::future::Either::Left(_) => tracing::trace!("GenServer stopped"), futures::future::Either::Right(_) => (), // Stream finished or errored out } }); - (join_handle, cancelation_token) + join_handle } diff --git a/concurrency/src/tasks/stream_tests.rs b/concurrency/src/tasks/stream_tests.rs index 5236363..e96c7e1 100644 --- a/concurrency/src/tasks/stream_tests.rs +++ b/concurrency/src/tasks/stream_tests.rs @@ -3,7 +3,7 @@ use std::time::Duration; use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream}; use crate::tasks::{ - stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle, + send_after, stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle, }; type SummatoryHandle = GenServerHandle; @@ -11,10 +11,12 @@ type SummatoryHandle = GenServerHandle; struct Summatory; type SummatoryState = u16; +type SummatoryOutMessage = SummatoryState; #[derive(Clone)] -struct UpdateSumatory { - added_value: u16, +enum SummatoryCastMessage { + Add(u16), + Stop, } impl Summatory { @@ -25,8 +27,8 @@ impl Summatory { impl GenServer for Summatory { type CallMsg = (); // We only handle one type of call, so there is no need for a specific message type. - type CastMsg = UpdateSumatory; - type OutMsg = SummatoryState; + type CastMsg = SummatoryCastMessage; + type OutMsg = SummatoryOutMessage; type State = SummatoryState; type Error = (); @@ -40,8 +42,13 @@ impl GenServer for Summatory { _handle: &GenServerHandle, state: Self::State, ) -> CastResponse { - let new_state = state + message.added_value; - CastResponse::NoReply(new_state) + match message { + SummatoryCastMessage::Add(val) => { + let new_state = state + val; + CastResponse::NoReply(new_state) + } + SummatoryCastMessage::Stop => CastResponse::Stop, + } } async fn handle_call( @@ -56,11 +63,9 @@ impl GenServer for Summatory { } // In this example, the stream sends u8 values, which are converted to the type -// supported by the GenServer (UpdateSumatory / u16). -fn message_builder(value: u8) -> UpdateSumatory { - UpdateSumatory { - added_value: value as u16, - } +// supported by the GenServer (SummatoryCastMessage / u16). +fn message_builder(value: u8) -> SummatoryCastMessage { + SummatoryCastMessage::Add(value.into()) } #[test] @@ -153,22 +158,34 @@ pub fn test_stream_cancellation() { } }); - let (_handle, cancellation_token) = spawn_listener( + let listener_handle = spawn_listener( summatory_handle.clone(), message_builder, ReceiverStream::new(rx), ); - // Wait for 1 second so the whole stream is processed - rt::sleep(Duration::from_millis(RUNNING_TIME)).await; + // Start a timer to stop the stream after a certain time + let summatory_handle_clone = summatory_handle.clone(); + let _ = send_after( + Duration::from_millis(RUNNING_TIME + 10), + summatory_handle_clone, + SummatoryCastMessage::Stop, + ); - cancellation_token.cancel(); + // Just before the stream is cancelled we retrieve the current value. + rt::sleep(Duration::from_millis(RUNNING_TIME)).await; + let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); // The reasoning for this assertion is that each message takes a quarter of the total time // to be processed, so having a stream of 5 messages, the last one won't be processed. // We could safely assume that it will get to process 4 messages, but in case of any extenal // slowdown, it could process less. - let val = Summatory::get_value(&mut summatory_handle).await.unwrap(); - assert!(val > 0 && val < 15); + assert!((1..=10).contains(&val)); + + assert!(listener_handle.await.is_ok()); + + // Finnally, we check that the server is stopped, by getting an error when trying to call it. + rt::sleep(Duration::from_millis(10)).await; + assert!(Summatory::get_value(&mut summatory_handle).await.is_err()); }) } diff --git a/concurrency/src/tasks/time.rs b/concurrency/src/tasks/time.rs index f26118b..619e553 100644 --- a/concurrency/src/tasks/time.rs +++ b/concurrency/src/tasks/time.rs @@ -22,9 +22,16 @@ where { let cancellation_token = CancellationToken::new(); let cloned_token = cancellation_token.clone(); + let gen_server_cancellation_token = handle.cancellation_token(); let join_handle = rt::spawn(async move { - let _ = select( + // Timer action is ignored if it was either cancelled or the associated GenServer is no longer running. + let cancel_conditions = select( Box::pin(cloned_token.cancelled()), + Box::pin(gen_server_cancellation_token.cancelled()), + ); + + let _ = select( + cancel_conditions, Box::pin(async { rt::sleep(period).await; let _ = handle.cast(message.clone()).await; @@ -49,10 +56,17 @@ where { let cancellation_token = CancellationToken::new(); let cloned_token = cancellation_token.clone(); + let gen_server_cancellation_token = handle.cancellation_token(); let join_handle = rt::spawn(async move { loop { - let result = select( + // Timer action is ignored if it was either cancelled or the associated GenServer is no longer running. + let cancel_conditions = select( Box::pin(cloned_token.cancelled()), + Box::pin(gen_server_cancellation_token.cancelled()), + ); + + let result = select( + Box::pin(cancel_conditions), Box::pin(async { rt::sleep(period).await; let _ = handle.cast(message.clone()).await; diff --git a/concurrency/src/tasks/timer_tests.rs b/concurrency/src/tasks/timer_tests.rs index d805c82..297a45c 100644 --- a/concurrency/src/tasks/timer_tests.rs +++ b/concurrency/src/tasks/timer_tests.rs @@ -149,6 +149,7 @@ enum DelayedCastMessage { #[derive(Clone)] enum DelayedCallMessage { GetCount, + Stop, } #[derive(PartialEq, Debug)] @@ -165,6 +166,10 @@ impl Delayed { .await .map_err(|_| ()) } + + pub async fn stop(server: &mut DelayedHandle) -> Result { + server.call(DelayedCallMessage::Stop).await.map_err(|_| ()) + } } impl GenServer for Delayed { @@ -180,12 +185,17 @@ impl GenServer for Delayed { async fn handle_call( &mut self, - _message: Self::CallMsg, + message: Self::CallMsg, _handle: &DelayedHandle, state: Self::State, ) -> CallResponse { - let count = state.count; - CallResponse::Reply(state, DelayedOutMessage::Count(count)) + match message { + DelayedCallMessage::GetCount => { + let count = state.count; + CallResponse::Reply(state, DelayedOutMessage::Count(count)) + } + DelayedCallMessage::Stop => CallResponse::Stop(DelayedOutMessage::Count(state.count)), + } } async fn handle_cast( @@ -246,3 +256,44 @@ pub fn test_send_after_and_cancellation() { assert_eq!(DelayedOutMessage::Count(1), count2); }); } + +#[test] +pub fn test_send_after_gen_server_teardown() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + // Start a Delayed + let mut repeater = Delayed::start(DelayedState { count: 0 }); + + // Set a just once timed message + let _ = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Wait for 200 milliseconds + rt::sleep(Duration::from_millis(200)).await; + + // Check count + let count = Delayed::get_count(&mut repeater).await.unwrap(); + + // Only one message (no repetition) + assert_eq!(DelayedOutMessage::Count(1), count); + + // New timer + let _ = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Stop the GenServer before timeout + let count2 = Delayed::stop(&mut repeater).await.unwrap(); + + // Wait another 200 milliseconds + rt::sleep(Duration::from_millis(200)).await; + + // As timer was cancelled, count should remain at 1 + assert_eq!(DelayedOutMessage::Count(1), count2); + }); +}