diff --git a/Cargo.lock b/Cargo.lock index ad7c48480..4a8308d98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4336,6 +4336,7 @@ dependencies = [ "magicblock-ledger", "magicblock-program", "rusqlite", + "serial_test", "solana-instruction 3.4.0", "solana-message 3.1.0", "solana-pubkey 3.0.0", diff --git a/magicblock-task-scheduler/Cargo.toml b/magicblock-task-scheduler/Cargo.toml index 0295a5f9d..fa0ede3d6 100644 --- a/magicblock-task-scheduler/Cargo.toml +++ b/magicblock-task-scheduler/Cargo.toml @@ -28,3 +28,6 @@ solana-transaction-error = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true, features = ["time"] } + +[dev-dependencies] +serial_test = { workspace = true } diff --git a/magicblock-task-scheduler/src/service.rs b/magicblock-task-scheduler/src/service.rs index aa2148551..1d30a15e0 100644 --- a/magicblock-task-scheduler/src/service.rs +++ b/magicblock-task-scheduler/src/service.rs @@ -9,7 +9,9 @@ use std::{ use futures_util::{future::poll_fn, FutureExt, StreamExt}; use magicblock_config::config::TaskSchedulerConfig; -use magicblock_core::link::transactions::ScheduledTasksRx; +use magicblock_core::{ + coordination_mode::CoordinationMode, link::transactions::ScheduledTasksRx, +}; use magicblock_ledger::LatestBlock; use magicblock_program::{ args::{CancelTaskRequest, ScheduleTaskRequest, TaskRequest}, @@ -134,6 +136,20 @@ impl TaskSchedulerService { pub async fn start( mut self, ) -> TaskSchedulerResult>> { + if self.is_primary_mode().await { + self.load_persisted_tasks().await?; + Ok(tokio::spawn(self.run())) + } else { + debug!("Task scheduler on standby mode does not start"); + Ok(tokio::spawn(async move { Ok(()) })) + } + } + + async fn load_persisted_tasks(&mut self) -> TaskSchedulerResult<()> { + self.task_queue.clear(); + self.task_queue_keys.clear(); + self.task_execution_retries.clear(); + // Reschedule all tasks that are due let tasks = self.db.get_tasks().await?; let now = chrono::Utc::now().timestamp_millis(); @@ -170,7 +186,7 @@ impl TaskSchedulerService { self.task_queue_keys.insert(task_id, key); } - Ok(tokio::spawn(self.run())) + Ok(()) } /// Main loop of the task scheduler. @@ -726,6 +742,20 @@ impl TaskSchedulerService { .unwrap_or(TASK_EXECUTION_RETRY_MAX_DELAY) .min(TASK_EXECUTION_RETRY_MAX_DELAY) } + + /// Waits until the coordination mode is not StartingUp. + /// Should be fast because task scheduler is started after the ledger replay completes. + async fn is_primary_mode(&self) -> bool { + let mut mode = CoordinationMode::current(); + while mode == CoordinationMode::StartingUp { + tokio::select! { + _ = self.token.cancelled() => return false, + _ = tokio::time::sleep(Duration::from_millis(100)) => {} + } + mode = CoordinationMode::current(); + } + mode == CoordinationMode::Primary + } } fn is_valid_task_interval(interval: i64) -> bool { @@ -755,12 +785,14 @@ fn is_retryable_task_execution_error(error: &TaskSchedulerError) -> bool { #[cfg(test)] mod tests { - use std::sync; - + use magicblock_core::coordination_mode::{ + switch_to_primary_mode, switch_to_replica_mode, + }; use magicblock_program::{ args::ScheduleTaskRequest, validator::generate_validator_authority_if_needed, }; + use serial_test::serial; use solana_pubkey::Pubkey; use tokio::{sync::mpsc, time::timeout}; @@ -780,7 +812,7 @@ mod tests { task_queue_keys: HashMap::new(), task_versions: HashMap::new(), task_execution_retries: HashMap::new(), - tx_counter: sync::Arc::new(AtomicU64::default()), + tx_counter: Arc::new(AtomicU64::default()), token: CancellationToken::new(), min_interval: Duration::from_millis(1000), failed_task_retention: Duration::from_secs(60), @@ -790,11 +822,13 @@ mod tests { } } + #[serial] #[test] fn test_first_execution_anchors_cadence_at_now() { assert_eq!(next_execution_millis(0, 50, 1_000), 1_000); } + #[serial] #[test] fn test_recurring_execution_preserves_fixed_rate_cadence() { let executed_at = next_execution_millis(1_000, 50, 1_090); @@ -804,14 +838,17 @@ mod tests { assert_eq!(delay, Duration::from_millis(10)); } + #[serial] #[test] fn test_overdue_execution_is_rescheduled_immediately() { assert_eq!(delay_until_millis(1_100, 1_150), Duration::from_millis(0)); } + #[serial] #[tokio::test] async fn test_schedule_invalid_tasks() { magicblock_core::logger::init_for_tests(); + switch_to_primary_mode(); generate_validator_authority_if_needed(); let (tx, rx) = mpsc::unbounded_channel(); @@ -859,9 +896,11 @@ mod tests { handle.abort(); } + #[serial] #[tokio::test] async fn test_remove_invalid_tasks_on_startup() { magicblock_core::logger::init_for_tests(); + switch_to_primary_mode(); let (_tx, rx) = mpsc::unbounded_channel(); let db = SchedulerDatabase::new(":memory:").unwrap(); @@ -909,9 +948,11 @@ mod tests { handle.abort(); } + #[serial] #[tokio::test] async fn test_completed_tasks_are_removed_on_startup() { magicblock_core::logger::init_for_tests(); + switch_to_primary_mode(); let (_tx, rx) = mpsc::unbounded_channel(); let db = SchedulerDatabase::new(":memory:").unwrap(); @@ -958,46 +999,11 @@ mod tests { handle.abort(); } + #[serial] #[tokio::test] async fn test_stale_crank_completion_does_not_mutate_replaced_task() { magicblock_core::logger::init_for_tests(); - - let (_tx, rx) = mpsc::unbounded_channel(); - let db = SchedulerDatabase::new(":memory:").unwrap(); - - db.insert_failed_scheduling(1, "schedule failed".to_string()) - .await - .unwrap(); - db.insert_failed_task(2, "task failed".to_string()) - .await - .unwrap(); - tokio::time::sleep(Duration::from_millis(2)).await; - - let mut service = test_service(db.clone(), rx); - service.failed_task_retention = Duration::from_millis(1); - service.failed_task_cleanup_interval = Duration::from_millis(5); - - let handle = service.start().await.unwrap(); - - timeout(Duration::from_secs(1), async move { - loop { - if db.get_failed_schedulings().await?.is_empty() - && db.get_failed_tasks().await?.is_empty() - { - return Ok::<_, TaskSchedulerError>(()); - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .unwrap() - .unwrap(); - handle.abort(); - } - - #[tokio::test] - async fn test_failed_records_are_cleaned_up_periodically() { - magicblock_core::logger::init_for_tests(); + switch_to_primary_mode(); let (_tx, rx) = mpsc::unbounded_channel(); let db = SchedulerDatabase::new(":memory:").unwrap(); @@ -1045,4 +1051,64 @@ mod tests { assert_eq!(queued.updated_at, replacement.updated_at); assert_eq!(queued.executions_left, replacement.executions_left); } + + #[serial] + #[tokio::test] + async fn test_task_scheduler_does_not_start_on_standby_mode() { + magicblock_core::logger::init_for_tests(); + switch_to_replica_mode(); + + let (_tx, rx) = mpsc::unbounded_channel(); + let db = SchedulerDatabase::new(":memory:").unwrap(); + let service = test_service(db.clone(), rx); + let handle = service.start().await.unwrap(); + + switch_to_primary_mode(); + + // Handle should join immediately because it's in standby mode + timeout(Duration::from_secs(1), handle) + .await + .unwrap() + .unwrap() + .unwrap(); + } + + #[serial] + #[tokio::test] + async fn test_failed_records_are_cleaned_up_periodically() { + magicblock_core::logger::init_for_tests(); + switch_to_primary_mode(); + + let (_tx, rx) = mpsc::unbounded_channel(); + let db = SchedulerDatabase::new(":memory:").unwrap(); + + db.insert_failed_scheduling(1, "schedule failed".to_string()) + .await + .unwrap(); + db.insert_failed_task(2, "task failed".to_string()) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(2)).await; + + let mut service = test_service(db.clone(), rx); + service.failed_task_retention = Duration::from_millis(1); + service.failed_task_cleanup_interval = Duration::from_millis(5); + + let handle = service.start().await.unwrap(); + + timeout(Duration::from_secs(1), async move { + loop { + if db.get_failed_schedulings().await?.is_empty() + && db.get_failed_tasks().await?.is_empty() + { + return Ok::<_, TaskSchedulerError>(()); + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .unwrap() + .unwrap(); + handle.abort(); + } }