From 42b53bd4355ef6da4746c471d73c428cdae9fd64 Mon Sep 17 00:00:00 2001
From: Feiran
Date: Fri, 27 Dec 2024 21:18:05 +0800
Subject: [PATCH] stream graceful shutdown

---
 Cargo.toml                          |   1 +
 examples/drain_and_stop/Cargo.toml  |  14 ++++
 examples/drain_and_stop/src/main.rs | 103 ++++++++++++++++++++++++++++
 hypersync-client/src/lib.rs         |   3 +-
 hypersync-client/src/stream.rs      |  57 ++++++++++++++---
 5 files changed, 173 insertions(+), 5 deletions(-)
 create mode 100644 examples/drain_and_stop/Cargo.toml
 create mode 100644 examples/drain_and_stop/src/main.rs

diff --git a/Cargo.toml b/Cargo.toml
index cdadb73..5ac3097 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,4 +11,5 @@ members = [
     "examples/watch",
     "examples/reverse_wallet",
     "examples/call_watch",
+    "examples/drain_and_stop",
 ]
diff --git a/examples/drain_and_stop/Cargo.toml b/examples/drain_and_stop/Cargo.toml
new file mode 100644
index 0000000..445a930
--- /dev/null
+++ b/examples/drain_and_stop/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "drain_and_stop"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+hypersync-client = { path = "../../hypersync-client" }
+
+tokio = { version = "1", features = ["full"] }
+serde_json = "1"
+polars-arrow = { version = "0.42", features = [
+    "compute_aggregate",
+] }
+env_logger = "0.11"
\ No newline at end of file
diff --git a/examples/drain_and_stop/src/main.rs b/examples/drain_and_stop/src/main.rs
new file mode 100644
index 0000000..e0bab98
--- /dev/null
+++ b/examples/drain_and_stop/src/main.rs
@@ -0,0 +1,103 @@
+// Example of using the client to stream data, then draining and stopping the stream early.
+// It has no practical use on its own; it is meant to show how to use the client.
+
+use std::sync::Arc;
+
+use hypersync_client::{Client, ClientConfig, ColumnMapping, DataType, StreamConfig};
+
+#[tokio::main]
+async fn main() {
+    env_logger::init();
+
+    // Create the default client; it uses eth mainnet.
+    let client = Client::new(ClientConfig::default()).unwrap();
+
+    let query = serde_json::from_value(serde_json::json!( {
+        // Start from block 10123123 and go to the end of the chain (we don't specify a toBlock).
+        "from_block": 10123123,
+        // The logs we want. We will also automatically get transactions and blocks relating to these logs (the query implicitly joins them).
+ "logs": [ + { + // We want All ERC20 transfers so no address filter and only a filter for the first topic + "topics": [ + ["0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"], + ] + } + ], + // Select the fields we are interested in, notice topics are selected as topic0,1,2,3 + "field_selection": { + "block": [ + "number", + ], + "log": [ + "data", + "topic0", + "topic1", + "topic2", + "topic3", + ] + } + })) + .unwrap(); + + println!("Starting the stream"); + + // Put the client inside Arc so we can use it for streaming + let client = Arc::new(client); + + let mut drained = vec![]; + + // Stream arrow data so we can average the erc20 transfer amounts in memory + // + // This will parallelize internal requests so we don't have to worry about pipelining/parallelizing make request -> handle response -> handle data loop + let mut receiver = client + .stream_arrow( + query, + StreamConfig { + // Pass the event signature for decoding + event_signature: Some( + "Transfer(address indexed from, address indexed to, uint amount)".to_owned(), + ), + column_mapping: Some(ColumnMapping { + decoded_log: [ + // Map the amount column to float so we can do aggregation on it + ("amount".to_owned(), DataType::Float64), + ] + .into_iter() + .collect(), + ..Default::default() + }), + ..Default::default() + }, + ) + .await + .unwrap(); + + let mut count = 0; + + // Receive the data in a loop + while let Some(res) = receiver.recv().await { + let res = res.unwrap(); + count += 1; + + println!( + "scanned up to block: {}, found {} blocks", + res.next_block, + res.data.blocks.len() + ); + + if res.next_block > 10129290 { + drained = receiver.drain_and_stop().await; + println!("Drained {} responses", drained.len()); + break; + } + } + + count += drained.len(); + + for data in drained { + println!("data: {:?}", data.unwrap().next_block); + } + + println!("response count: {}", count); +} diff --git a/hypersync-client/src/lib.rs b/hypersync-client/src/lib.rs index c9287c5..49e905f 100644 --- a/hypersync-client/src/lib.rs +++ b/hypersync-client/src/lib.rs @@ -30,6 +30,7 @@ pub use hypersync_schema as schema; use parse_response::parse_query_response; use simple_types::Event; +use stream::ArrowStream; use tokio::sync::mpsc; use types::{EventResponse, ResponseData}; use url::Url; @@ -525,7 +526,7 @@ impl Client { self: Arc, query: Query, config: StreamConfig, - ) -> Result>> { + ) -> Result { stream::stream_arrow(self, query, config).await } diff --git a/hypersync-client/src/stream.rs b/hypersync-client/src/stream.rs index 958e1d3..fb5b304 100644 --- a/hypersync-client/src/stream.rs +++ b/hypersync-client/src/stream.rs @@ -14,8 +14,9 @@ use polars_arrow::{ datatypes::ArrowDataType, record_batch::RecordBatch, }; -use tokio::sync::mpsc; use tokio::task::JoinSet; +use tokio::{sync::mpsc, task::JoinHandle}; +use tokio_util::sync::CancellationToken; use crate::{ config::HexOutput, @@ -25,11 +26,41 @@ use crate::{ ArrowBatch, ArrowResponseData, StreamConfig, }; +pub struct ArrowStream { + // Used to cancel the stream + cancel_token: CancellationToken, + // Join handle for waiting for the stream to finish + handle: JoinHandle<()>, + // Receiver for the stream + rx: mpsc::Receiver>, +} + +impl ArrowStream { + pub async fn recv(&mut self) -> Option> { + self.rx.recv().await + } + + /// Signals all tasks to stop via cancellation, then waits for them + /// to finish and drains any leftover items. 
+    pub async fn drain_and_stop(self) -> Vec<Result<ArrowResponse>> {
+        self.cancel_token.cancel();
+
+        let mut drained = Vec::new();
+        let mut rx = self.rx;
+
+        while let Some(item) = rx.recv().await {
+            drained.push(item);
+        }
+
+        // The channel only closes once the producer task has dropped its sender, so await the handle to finish it cleanly.
+        let _ = self.handle.await;
+
+        drained
+    }
+}
+
 pub async fn stream_arrow(
     client: Arc<Client>,
     query: Query,
     config: StreamConfig,
-) -> Result<mpsc::Receiver<Result<ArrowResponse>>> {
+) -> Result<ArrowStream> {
     let concurrency = config.concurrency.unwrap_or(10);
     let batch_size = config.batch_size.unwrap_or(1000);
     let max_batch_size = config.max_batch_size.unwrap_or(200_000);
@@ -42,12 +76,19 @@
 
     let (tx, rx) = mpsc::channel(concurrency * 2);
 
+    let cancel_token = CancellationToken::new();
+    let cancel_token_clone = cancel_token.clone();
+
     let to_block = match query.to_block {
         Some(to_block) => to_block,
        None => client.get_height().await.context("get height")?,
     };
 
-    tokio::spawn(async move {
+    let handle = tokio::spawn(async move {
+        if cancel_token.is_cancelled() {
+            return;
+        }
+
         let mut query = query;
 
         if !reverse {
@@ -97,6 +138,10 @@
         let mut next_req_idx = 0;
 
         while futs.peek().is_some() {
+            if cancel_token.is_cancelled() {
+                break;
+            }
+
             while let Some(res) = set.try_join_next() {
                 let (generation, req_idx, resps) = res.unwrap();
                 queue.insert(req_idx, (generation, resps));
@@ -208,7 +253,11 @@
         }
     });
 
-    Ok(rx)
+    Ok(ArrowStream {
+        cancel_token: cancel_token_clone,
+        handle,
+        rx,
+    })
 }
 
 fn count_rows(batches: &[ArrowBatch]) -> usize {
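
Note on the shutdown pattern: the cancel-then-drain scheme ArrowStream uses
above generalizes beyond this client. Below is a minimal, self-contained
sketch of the same scheme using only tokio and tokio-util (assumed
dependencies: tokio = { version = "1", features = ["full"] } and
tokio-util = "0.7"); every name in it is illustrative, and none of it is
hypersync-client API.

use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    // Bounded channel: the producer parks on send() when it is full, just
    // like the stream task above parks when the consumer falls behind.
    let (tx, mut rx) = mpsc::channel::<u64>(4);
    let token = CancellationToken::new();
    let producer_token = token.clone();

    let handle = tokio::spawn(async move {
        for i in 0..u64::MAX {
            // Check for cancellation between units of work.
            if producer_token.is_cancelled() {
                break;
            }
            if tx.send(i).await.is_err() {
                break; // receiver was dropped
            }
        }
        // tx is dropped here; the channel closes once buffered items drain.
    });

    // Consume a few items, then shut down gracefully.
    for _ in 0..3 {
        println!("got {:?}", rx.recv().await);
    }

    token.cancel();

    // recv() only returns None after every sender is dropped, i.e. after
    // the producer has exited, so draining doubles as waiting for shutdown.
    // Draining also unparks a producer that is blocked on a full channel.
    let mut leftover = Vec::new();
    while let Some(item) = rx.recv().await {
        leftover.push(item);
    }
    handle.await.unwrap();

    println!("drained {} leftover items", leftover.len());
}

This is also why drain_and_stop drains before awaiting the join handle:
cancelling and then dropping the receiver would discard any buffered
responses (send() fails once the receiver is gone), while cancelling and
then draining preserves them and still guarantees the producer task exits.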