diff --git a/lading/src/generator/grpc.rs b/lading/src/generator/grpc.rs index 822e9c06b..57c7adaa8 100644 --- a/lading/src/generator/grpc.rs +++ b/lading/src/generator/grpc.rs @@ -13,7 +13,7 @@ //! Additional metrics may be emitted by this generator's [throttle]. //! -use std::{convert::TryFrom, num::NonZeroU32, time::Duration}; +use std::{convert::TryFrom, num::NonZeroU16, num::NonZeroU32, sync::Arc, time::Duration}; use byte_unit::Byte; use bytes::{Buf, BufMut, Bytes}; @@ -22,9 +22,11 @@ use http::{ uri::{self, PathAndQuery}, }; use metrics::counter; +use once_cell::sync::OnceCell; use rand::SeedableRng; use rand::rngs::StdRng; use serde::{Deserialize, Serialize}; +use tokio::sync::Semaphore; use tonic::{ Request, Response, Status, client, codec::{DecodeBuf, Decoder, EncodeBuf, Encoder}, @@ -36,9 +38,12 @@ use lading_payload::block; use super::General; use crate::generator::common::{ - BlockThrottle, MetricsBuilder, ThrottleConfig, ThrottleConversionError, create_throttle, + BlockThrottle, ConcurrencyStrategy, MetricsBuilder, ThrottleConfig, ThrottleConversionError, + create_throttle, }; +static CONNECTION_SEMAPHORE: OnceCell = OnceCell::new(); + /// Errors produced by [`Grpc`] #[derive(thiserror::Error, Debug)] pub enum Error { @@ -171,12 +176,12 @@ impl Decoder for CountingDecoder { /// This generator is able to connect to targets via gRPC. #[derive(Debug)] pub struct Grpc { - config: Config, target_uri: Uri, rpc_path: PathAndQuery, + concurrency: ConcurrencyStrategy, shutdown: lading_signal::Watcher, throttle: BlockThrottle, - block_cache: block::Cache, + block_cache: Arc, metric_labels: Vec<(String, String)>, } @@ -223,8 +228,15 @@ impl Grpc { )?, }; + let concurrency = + ConcurrencyStrategy::new(NonZeroU16::new(config.parallel_connections), false); + + CONNECTION_SEMAPHORE + .set(Semaphore::new(concurrency.connection_count() as usize)) + .expect("failed to set semaphore"); + let target_uri = - uri::Uri::try_from(config.target_uri.clone()).expect("target_uri must be valid"); + uri::Uri::try_from(config.target_uri).expect("target_uri must be valid"); let rpc_path = target_uri .path_and_query() .cloned() @@ -232,9 +244,9 @@ impl Grpc { Ok(Self { target_uri, rpc_path, - config, + concurrency, shutdown, - block_cache, + block_cache: Arc::new(block_cache), throttle, metric_labels: labels, }) @@ -247,7 +259,6 @@ impl Grpc { let uri = Uri::from_parts(parts).expect("failed to convert parts into uri"); let endpoint = transport::Endpoint::new(uri)?; - let endpoint = endpoint.concurrency_limit(self.config.parallel_connections as usize); let endpoint = endpoint.connect_timeout(Duration::from_secs(1)); let conn = endpoint.connect().await?; let conn = client::Grpc::new(conn); @@ -284,7 +295,7 @@ impl Grpc { /// /// Function will panic if underlying byte capacity is not available. pub async fn spin(mut self) -> Result<(), Error> { - let mut client = loop { + let client = loop { match self.connect().await { Ok(c) => break c, Err(source) => { @@ -300,6 +311,7 @@ impl Grpc { let mut handle = self.block_cache.handle(); let rpc_path = self.rpc_path; + let labels = self.metric_labels; let shutdown_wait = self.shutdown.recv(); tokio::pin!(shutdown_wait); @@ -308,43 +320,70 @@ impl Grpc { result = self.throttle.wait_for_block(&self.block_cache, &handle) => { let _ = result; let block = self.block_cache.advance(&mut handle); + let block_bytes = Bytes::copy_from_slice(&block.bytes); let block_length = block.bytes.len(); - counter!("requests_sent", &self.metric_labels).increment(1); - let res = Self::req( - &mut client, - rpc_path.clone(), - Bytes::copy_from_slice(&block.bytes), - ) - .await; - - match res { - Ok(res) => { - counter!("bytes_written", &self.metric_labels).increment(block_length as u64); - if let Some(data_points) = block.metadata.data_points { - counter!("data_points_transmitted", &self.metric_labels).increment(data_points); + let data_points = block.metadata.data_points; + + let mut task_client = client.clone(); + let task_rpc_path = rpc_path.clone(); + let task_labels = labels.clone(); + let target_uri = self.target_uri.clone(); + + let permit = CONNECTION_SEMAPHORE + .get() + .expect("connection semaphore not initialized") + .acquire() + .await + .expect("connection semaphore closed"); + tokio::spawn(async move { + counter!("requests_sent", &task_labels).increment(1); + let res = Self::req( + &mut task_client, + task_rpc_path.clone(), + block_bytes, + ) + .await; + + match res { + Ok(res) => { + counter!("bytes_written", &task_labels) + .increment(block_length as u64); + if let Some(dp) = data_points { + counter!("data_points_transmitted", &task_labels) + .increment(dp); + } + counter!("request_ok", &task_labels).increment(1); + counter!("response_bytes", &task_labels) + .increment(res.into_inner() as u64); + } + Err(source) => { + error!( + "Failed to make RPC request to {endpoint}{path}: {source}", + endpoint = target_uri, + path = task_rpc_path + ); + let mut error_labels = task_labels.clone(); + error_labels + .push(("error".to_string(), source.to_string())); + counter!("request_failure", &error_labels).increment(1); } - counter!("request_ok", &self.metric_labels).increment(1); - counter!("response_bytes", &self.metric_labels).increment(res.into_inner() as u64); - } - Err(source) => { - error!( - "Failed to make RPC request to {endpoint}{path}: {source}", - endpoint = self.target_uri, - path = rpc_path - ); - let mut error_labels = self.metric_labels.clone(); - error_labels.push(("error".to_string(), source.to_string())); - counter!("request_failure", &error_labels).increment(1); } - } + drop(permit); + }); }, () = &mut shutdown_wait => { info!("shutdown signal received"); - break; + // Acquire all permits to ensure in-flight tasks complete + // before returning. + let _semaphore = CONNECTION_SEMAPHORE + .get() + .expect("connection semaphore not initialized") + .acquire_many(u32::from(self.concurrency.connection_count())) + .await + .expect("connection semaphore closed"); + return Ok(()); }, } } - - Ok(()) } }