From 87b86806c5602d3ac2d86be0c1ba5b817c6a4070 Mon Sep 17 00:00:00 2001 From: E John Feig Date: Thu, 21 May 2026 15:13:44 -0700 Subject: [PATCH] feat: Implement Least Request Load Balancing Policy (gRFC A48) Implements the "all weights equal" Least Request Load Balancing policy in gRPC-Rust, in compliance with gRFC A48. The Least Request policy improves tail latencies in heterogeneous environments by tracking active request counts per endpoint and directing new requests to the backend with the lowest load. Detailed Changes: 1. Core Load Balancing Policy (`least_request.rs`): - Defined `LeastRequestLoadBalancingConfig` to parse and validate the `choiceCount` config parameter (default = 2, clamped from 2 to 10). - Implemented `LeastRequestBuilder` registering policy name `least_request_experimental`. - Implemented `LeastRequestPolicy` managing endpoint-level connections via `ChildManager` children delegating to `pick_first`. - Maintained a persistent mapping of weak subchannel references to active request counters (`subchannel_counters`) so that outstanding request metrics survive picker updates and name re-resolutions. - Implemented `LeastRequestPicker` utilizing a random sampling selection algorithm over `choice_count` subchannels. 2. Active Request Cancellation Safety: - Identified and resolved a request counter leak bug where async task cancellations during `dyn_invoke.await` dropped the `Pick` closure without calling it. - Implemented a custom, defusable `ActiveRequestGuard` using an `AtomicBool` inside `LeastRequestPicker::pick`. The guard guarantees that the active request count is decremented upon drop if the picker's `on_complete` callback is never invoked. 3. Channel & Service Config Integration: - Registered the builder with the global LB registry in `Channel::new` inside `channel.rs`. - Added `CallbackRecvStream` wrapping the stream in the channel's `Invoke` implementation to trigger `on_complete` callbacks when client streams are completed or dropped. - Added `LeastRequest` variant to `LbPolicyType` enum in `service_config.rs`. - Mapped `LbPolicyType::LeastRequest` configuration inside `ResolverChannelController::update` in `channel.rs`. 4. Test Additions & Verification: - Added comprehensive unit tests in `least_request.rs` covering configuration parsing/clamping/validation, least request selection, tie-breaking, fewer subchannels than choice count, and cancellation drop-guard safety. - Modified the `InMemoryResolver` in `inmemory/mod.rs` to dynamically set the `LeastRequest` load-balancing policy based on target URI path prefixes. - Wrote a robust E2E integration test `test_in_memory_least_request_load_balancing` in `inmemory/mod.rs` verifying dynamic load balancing across multiple in-memory backends concurrently. --- grpc/src/client/channel.rs | 69 +- .../client/load_balancing/least_request.rs | 647 ++++++++++++++++++ grpc/src/client/load_balancing/mod.rs | 1 + grpc/src/client/service_config.rs | 1 + grpc/src/inmemory/mod.rs | 112 ++- 5 files changed, 817 insertions(+), 13 deletions(-) create mode 100644 grpc/src/client/load_balancing/least_request.rs diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index c93697fc3..cd46d6a98 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -56,6 +56,7 @@ use crate::client::load_balancing::Picker; use crate::client::load_balancing::QueuingPicker; use crate::client::load_balancing::WorkScheduler; use crate::client::load_balancing::graceful_switch::GracefulSwitchPolicy; +use crate::client::load_balancing::least_request; use crate::client::load_balancing::pick_first; use crate::client::load_balancing::round_robin; use crate::client::load_balancing::subchannel::Subchannel; @@ -174,6 +175,7 @@ impl Channel { { pick_first::reg(); round_robin::reg(); + least_request::reg(); dns::reg(); #[cfg(unix)] name_resolution::unix::reg(); @@ -389,11 +391,20 @@ impl Invoke for Arc { "channel has been closed", )); }; - let result = &state.picker.pick(&headers); + let result = state.picker.pick(&headers); match result { - PickResult::Pick(pr) => { + PickResult::Pick(mut pr) => { if let Some(sc) = pr.subchannel.downcast_ref::() { - return sc.dyn_invoke(headers, options.clone()).await; + let (tx, rx) = sc.dyn_invoke(headers, options.clone()).await; + let rx = if let Some(on_complete) = pr.on_complete.take() { + Box::new(CallbackRecvStream { + delegate: rx, + on_complete: Some(on_complete), + }) as Box + } else { + rx + }; + return (tx, rx); } else { panic!( "picked subchannel is not an implementation provided by the channel" @@ -404,7 +415,7 @@ impl Invoke for Arc { // Continue and retry the RPC with the next picker. } PickResult::Fail(status) => { - return FailingRecvStream::new_stream_pair(status.clone()); + return FailingRecvStream::new_stream_pair(status); } PickResult::Drop(status) => { todo!("dropped pick: {:?}", status); @@ -420,6 +431,39 @@ impl Drop for ActiveChannel { } } +struct CallbackRecvStream { + delegate: Box, + on_complete: Option>, +} + +#[tonic::async_trait] +impl DynRecvStream for CallbackRecvStream { + async fn dyn_recv( + &mut self, + msg: &mut dyn crate::core::RecvMessage, + ) -> crate::client::ResponseStreamItem { + let item = self.delegate.dyn_recv(msg).await; + if matches!( + item, + crate::client::ResponseStreamItem::Trailers(_) + | crate::client::ResponseStreamItem::StreamClosed + ) { + if let Some(cb) = self.on_complete.take() { + cb(); + } + } + item + } +} + +impl Drop for CallbackRecvStream { + fn drop(&mut self) { + if let Some(cb) = self.on_complete.take() { + cb(); + } + } +} + struct ResolverWorkScheduler { wqtx: WorkQueueTx, } @@ -474,13 +518,22 @@ impl ResolverChannelController { impl name_resolution::ChannelController for ResolverChannelController { fn update(&mut self, update: ResolverUpdate) -> Result<(), String> { - let json_config = if let Ok(Some(service_config)) = update.service_config.as_ref() - && service_config + let json_config = if let Ok(Some(service_config)) = update.service_config.as_ref() { + if service_config .load_balancing_policy .as_ref() .is_some_and(|p| *p == LbPolicyType::RoundRobin) - { - json!([{round_robin::POLICY_NAME: {}}]) + { + json!([{round_robin::POLICY_NAME: {}}]) + } else if service_config + .load_balancing_policy + .as_ref() + .is_some_and(|p| *p == LbPolicyType::LeastRequest) + { + json!([{least_request::POLICY_NAME: {}}]) + } else { + json!([{pick_first::POLICY_NAME: {"shuffleAddressList": true, "unknown_field": false}}]) + } } else { json!([{pick_first::POLICY_NAME: {"shuffleAddressList": true, "unknown_field": false}}]) }; diff --git a/grpc/src/client/load_balancing/least_request.rs b/grpc/src/client/load_balancing/least_request.rs new file mode 100644 index 000000000..4be24faf9 --- /dev/null +++ b/grpc/src/client/load_balancing/least_request.rs @@ -0,0 +1,647 @@ +/* + * + * Copyright 2026 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; +use std::sync::Once; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +use crate::client::ConnectivityState; +use crate::client::load_balancing::ChannelController; +use crate::client::load_balancing::DynLbPolicyBuilder; +use crate::client::load_balancing::FailingPicker; +use crate::client::load_balancing::GLOBAL_LB_REGISTRY; +use crate::client::load_balancing::LbPolicy; +use crate::client::load_balancing::LbPolicyBuilder; +use crate::client::load_balancing::LbPolicyOptions; +use crate::client::load_balancing::LbState; +use crate::client::load_balancing::ParsedJsonLbConfig; +use crate::client::load_balancing::PickResult; +use crate::client::load_balancing::Picker; +use crate::client::load_balancing::Subchannel; +use crate::client::load_balancing::SubchannelState; +use crate::client::load_balancing::child_manager::ChildManager; +use crate::client::load_balancing::child_manager::ChildUpdate; +use crate::client::load_balancing::pick_first; +use crate::client::load_balancing::subchannel::WeakSubchannel; +use crate::client::name_resolution::Endpoint; +use crate::client::name_resolution::ResolverUpdate; +use crate::core::RequestHeaders; + +pub(crate) static POLICY_NAME: &str = "least_request_experimental"; +static START: Once = Once::new(); + +#[derive(serde::Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub(crate) struct LeastRequestLoadBalancingConfig { + #[serde(default = "default_choice_count")] + pub choice_count: u32, +} + +fn default_choice_count() -> u32 { + 2 +} + +#[derive(Debug)] +pub(crate) struct LeastRequestBuilder {} + +impl LbPolicyBuilder for LeastRequestBuilder { + type LbPolicy = LeastRequestPolicy; + + fn build(&self, options: LbPolicyOptions) -> Self::LbPolicy { + let child_manager = ChildManager::new(options.runtime, options.work_scheduler); + LeastRequestPolicy::new( + child_manager, + GLOBAL_LB_REGISTRY + .get_policy(pick_first::POLICY_NAME) + .unwrap(), + ) + } + + fn name(&self) -> &'static str { + POLICY_NAME + } + + fn parse_config( + &self, + config: &ParsedJsonLbConfig, + ) -> Result::LbConfig>, String> { + let parsed: LeastRequestLoadBalancingConfig = config + .convert_to() + .map_err(|e| format!("failed to parse least_request config: {e}"))?; + + if parsed.choice_count < 2 { + return Err("choice_count must be at least 2".to_string()); + } + + let choice_count = parsed.choice_count.min(10); + Ok(Some(LeastRequestLoadBalancingConfig { choice_count })) + } +} + +#[derive(Debug)] +pub(crate) struct LeastRequestPolicy { + child_manager: ChildManager, + pick_first_builder: Arc, + choice_count: u32, + subchannel_counters: HashMap>, +} + +impl LeastRequestPolicy { + fn new( + child_manager: ChildManager, + pick_first_builder: Arc, + ) -> Self { + Self { + child_manager, + pick_first_builder, + choice_count: 2, + subchannel_counters: HashMap::new(), + } + } + + // Sets the policy's state to TRANSIENT_FAILURE with a picker returning the + // error string provided, then requests re-resolution from the channel. + fn move_to_transient_failure( + &mut self, + error: String, + channel_controller: &mut dyn ChannelController, + ) { + channel_controller.update_picker(LbState { + connectivity_state: ConnectivityState::TransientFailure, + picker: Arc::new(FailingPicker { error }), + }); + channel_controller.request_resolution(); + } + + // Sends an aggregate picker based on states of children. + fn update_picker(&mut self, channel_controller: &mut dyn ChannelController) { + if !self.child_manager.child_updated() { + return; + } + let aggregate_state = self.child_manager.aggregate_states(); + + if aggregate_state == ConnectivityState::Ready { + let mut ready_subchannels = Vec::new(); + for child in self.child_manager.children() { + if child.state.connectivity_state == ConnectivityState::Ready { + if let PickResult::Pick(pick) = child + .state + .picker + .pick(&crate::core::RequestHeaders::default()) + { + let weak = WeakSubchannel::new(&pick.subchannel); + let counter = self + .subchannel_counters + .entry(weak) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))) + .clone(); + ready_subchannels.push(SubchannelWithCounter { + subchannel: pick.subchannel.clone(), + active_requests: counter, + }); + } + } + } + + // Clean up stale counters + self.subchannel_counters + .retain(|weak, _| weak.upgrade().is_some()); + + let picker_update = LbState { + connectivity_state: aggregate_state, + picker: Arc::new(LeastRequestPicker { + subchannels: ready_subchannels, + choice_count: self.choice_count as usize, + }), + }; + channel_controller.update_picker(picker_update); + } else { + // Forward the child picker for non-ready aggregate state + let picker = self + .child_manager + .children() + .find(|cs| cs.state.connectivity_state == aggregate_state) + .map(|cs| cs.state.picker.clone()) + .unwrap_or_else(|| { + Arc::new(crate::client::load_balancing::QueuingPicker) as Arc + }); + + channel_controller.update_picker(LbState { + connectivity_state: aggregate_state, + picker, + }); + } + } + + // Responds to an incoming ResolverUpdate containing an Err in endpoints by + // forwarding it to all children unconditionally. Updates the picker as + // needed. + fn handle_resolver_error( + &mut self, + resolver_update: ResolverUpdate, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), String> { + let err = format!( + "Received error from name resolver: {}", + resolver_update.endpoints.as_ref().unwrap_err() + ); + if self.child_manager.children().next().is_none() { + // We had no children so we must produce an erroring picker. + self.move_to_transient_failure(err.clone(), channel_controller); + return Err(err); + } + // Forward the error to each child, ignoring their responses. + let _ = self + .child_manager + .resolver_update(resolver_update, None, channel_controller); + self.update_picker(channel_controller); + Err(err) + } +} + +impl LbPolicy for LeastRequestPolicy { + type LbConfig = LeastRequestLoadBalancingConfig; + + fn resolver_update( + &mut self, + update: ResolverUpdate, + config: Option<&Self::LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), String> { + if let Some(cfg) = config { + self.choice_count = cfg.choice_count; + } + + if update.endpoints.is_err() { + return self.handle_resolver_error(update, channel_controller); + } + + // Shard the update by endpoint. + let updates = update.endpoints.as_ref().unwrap().iter().map(|e| { + let update = ResolverUpdate { + attributes: crate::attributes::Attributes::default(), + endpoints: Ok(vec![e.clone()]), + service_config: update.service_config.clone(), + resolution_note: None, + }; + ChildUpdate { + child_identifier: e.clone(), + child_policy_builder: self.pick_first_builder.clone(), + child_update: Some((update, None)), + } + }); + self.child_manager + .update(updates, channel_controller) + .unwrap(); + + if self.child_manager.children().next().is_none() { + // There are no children remaining, so report this error and produce + // an erroring picker. + let err = "Received empty address list from the name resolver"; + self.move_to_transient_failure(err.into(), channel_controller); + return Err(err.into()); + } + + self.update_picker(channel_controller); + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + self.child_manager + .subchannel_update(subchannel, state, channel_controller); + self.update_picker(channel_controller); + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + self.child_manager.work(channel_controller); + self.update_picker(channel_controller); + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + self.child_manager.exit_idle(channel_controller); + self.update_picker(channel_controller); + } +} + +/// Register least request as a LbPolicy. +pub(crate) fn reg() { + START.call_once(|| { + GLOBAL_LB_REGISTRY.add_builder(LeastRequestBuilder {}); + }); +} + +#[derive(Clone, Debug)] +struct SubchannelWithCounter { + subchannel: Arc, + active_requests: Arc, +} + +#[derive(Debug)] +struct LeastRequestPicker { + subchannels: Vec, + choice_count: usize, +} + +impl Picker for LeastRequestPicker { + fn pick(&self, _request_headers: &RequestHeaders) -> PickResult { + let len = self.subchannels.len(); + if len == 0 { + return PickResult::Queue; + } + + let sample_limit = self.choice_count.min(len); + let mut best_idx: Option = None; + let mut best_active_requests = usize::MAX; + + for _ in 0..sample_limit { + let idx = if len == 1 { + 0 + } else { + rand::random_range(0..len) + }; + let active_reqs = self.subchannels[idx] + .active_requests + .load(Ordering::Relaxed); + if best_idx.is_none() || active_reqs < best_active_requests { + best_idx = Some(idx); + best_active_requests = active_reqs; + } + } + + let selected_idx = best_idx.unwrap(); + let selected = &self.subchannels[selected_idx]; + + selected.active_requests.fetch_add(1, Ordering::Relaxed); + + let active = Arc::new(std::sync::atomic::AtomicBool::new(true)); + let counter = selected.active_requests.clone(); + + struct ActiveRequestGuard { + counter: Arc, + active: Arc, + } + + impl Drop for ActiveRequestGuard { + fn drop(&mut self) { + if self.active.swap(false, Ordering::Relaxed) { + self.counter.fetch_sub(1, Ordering::Relaxed); + } + } + } + + let guard = ActiveRequestGuard { + counter: counter.clone(), + active: active.clone(), + }; + + let counter_clone = counter.clone(); + let on_complete = Box::new(move || { + if active.swap(false, Ordering::Relaxed) { + counter_clone.fetch_sub(1, Ordering::Relaxed); + } + let _ = &guard; + }); + + PickResult::Pick(crate::client::load_balancing::Pick { + subchannel: selected.subchannel.clone(), + metadata: crate::metadata::MetadataMap::new(), + on_complete: Some(on_complete), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::load_balancing::test_utils::{TestChannelController, TestWorkScheduler}; + use crate::client::name_resolution::Address; + use crate::rt::default_runtime; + use std::sync::atomic::Ordering; + use std::sync::mpsc; + + #[derive(Debug, Clone)] + struct MockSubchannel { + address: Address, + } + + impl crate::client::load_balancing::subchannel::private::Sealed for MockSubchannel {} + impl crate::client::load_balancing::subchannel::DynHash for MockSubchannel { + fn dyn_hash(&self, state: &mut Box<&mut dyn std::hash::Hasher>) { + use std::hash::Hash; + self.address.hash(state); + } + } + impl crate::client::load_balancing::subchannel::DynPartialEq for MockSubchannel { + fn dyn_eq(&self, other: &&dyn std::any::Any) -> bool { + if let Some(other) = other.downcast_ref::() { + self.address == other.address + } else { + false + } + } + } + impl Subchannel for MockSubchannel { + fn address(&self) -> Address { + self.address.clone() + } + fn connect(&self) {} + } + + #[test] + fn test_config_parsing() { + let builder = LeastRequestBuilder {}; + + // Default choice count + let default_config = ParsedJsonLbConfig::new("{}").unwrap(); + let parsed = builder.parse_config(&default_config).unwrap().unwrap(); + assert_eq!(parsed.choice_count, 2); + + // Explicit valid choice count + let valid_config = ParsedJsonLbConfig::new("{\"choiceCount\": 5}").unwrap(); + let parsed = builder.parse_config(&valid_config).unwrap().unwrap(); + assert_eq!(parsed.choice_count, 5); + + // Clamped choice count + let high_config = ParsedJsonLbConfig::new("{\"choiceCount\": 15}").unwrap(); + let parsed = builder.parse_config(&high_config).unwrap().unwrap(); + assert_eq!(parsed.choice_count, 10); + + // Rejected choice count + let low_config = ParsedJsonLbConfig::new("{\"choiceCount\": 1}").unwrap(); + assert!(builder.parse_config(&low_config).is_err()); + } + + #[test] + fn test_picker_least_request_selection() { + let sc1 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:80".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let sc2 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:81".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let count1 = Arc::new(AtomicUsize::new(5)); + let count2 = Arc::new(AtomicUsize::new(2)); + + let picker = LeastRequestPicker { + subchannels: vec![ + SubchannelWithCounter { + subchannel: sc1.clone(), + active_requests: count1.clone(), + }, + SubchannelWithCounter { + subchannel: sc2.clone(), + active_requests: count2.clone(), + }, + ], + choice_count: 2, + }; + + // Run the pick in a loop up to 20 times since random sampling with replacement + // might occasionally select sc1 twice (with 25% probability). + let mut picked_sc2 = false; + for _ in 0..20 { + let res = picker.pick(&RequestHeaders::default()); + let pick = res.unwrap_pick(); + if pick.subchannel.address().address == "127.0.0.1:81".to_string().into() { + picked_sc2 = true; + // Active request count of the selected subchannel should have incremented + assert_eq!(count2.load(Ordering::Relaxed), 3); + let on_complete = pick.on_complete.unwrap(); + on_complete(); + assert_eq!(count2.load(Ordering::Relaxed), 2); + break; + } + } + assert!( + picked_sc2, + "sc2 (with fewer requests) was never picked in 20 attempts" + ); + } + + #[test] + fn test_picker_tie_breaking() { + let sc1 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:80".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let sc2 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:81".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let count1 = Arc::new(AtomicUsize::new(2)); + let count2 = Arc::new(AtomicUsize::new(2)); + + let picker = LeastRequestPicker { + subchannels: vec![ + SubchannelWithCounter { + subchannel: sc1.clone(), + active_requests: count1.clone(), + }, + SubchannelWithCounter { + subchannel: sc2.clone(), + active_requests: count2.clone(), + }, + ], + choice_count: 2, + }; + + // With identical active request counts, either sc1 or sc2 should be chosen + let res = picker.pick(&RequestHeaders::default()); + let pick = res.unwrap_pick(); + let chosen_addr = pick.subchannel.address().address.to_string(); + assert!(chosen_addr == "127.0.0.1:80" || chosen_addr == "127.0.0.1:81"); + } + + #[test] + fn test_picker_fewer_subchannels_than_choice_count() { + let sc1 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:80".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let sc2 = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:81".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let count1 = Arc::new(AtomicUsize::new(5)); + let count2 = Arc::new(AtomicUsize::new(2)); + + let picker = LeastRequestPicker { + subchannels: vec![ + SubchannelWithCounter { + subchannel: sc1.clone(), + active_requests: count1.clone(), + }, + SubchannelWithCounter { + subchannel: sc2.clone(), + active_requests: count2.clone(), + }, + ], + // choice_count is 3, but only 2 subchannels are available + choice_count: 3, + }; + + // Picker should handle this gracefully by sampling both subchannels, + // and picking the one with fewer active requests (sc2). + let mut picked_sc2 = false; + for _ in 0..20 { + let res = picker.pick(&RequestHeaders::default()); + let pick = res.unwrap_pick(); + if pick.subchannel.address().address == "127.0.0.1:81".to_string().into() { + picked_sc2 = true; + // Active request count of the selected subchannel should have incremented + assert_eq!(count2.load(Ordering::Relaxed), 3); + let on_complete = pick.on_complete.unwrap(); + on_complete(); + assert_eq!(count2.load(Ordering::Relaxed), 2); + break; + } + } + assert!( + picked_sc2, + "sc2 (with fewer requests) was never picked in 20 attempts" + ); + } + + #[test] + fn test_picker_cancellation_drop_guard() { + let sc = Arc::new(MockSubchannel { + address: Address { + address: "127.0.0.1:80".to_string().into(), + ..Default::default() + }, + }) as Arc; + + let count = Arc::new(AtomicUsize::new(5)); + + let picker = LeastRequestPicker { + subchannels: vec![SubchannelWithCounter { + subchannel: sc, + active_requests: count.clone(), + }], + choice_count: 1, + }; + + // Pick once + let res = picker.pick(&RequestHeaders::default()); + assert_eq!(count.load(Ordering::Relaxed), 6); + + // Simulate cancellation/drop of Pick without calling on_complete callback + drop(res); + + // Count must have been decremented back to 5 by the Drop guard + assert_eq!(count.load(Ordering::Relaxed), 5); + } + + #[test] + fn test_policy_empty_resolver_update() { + let (tx_events, _rx_events) = mpsc::channel(); + let work_scheduler = Arc::new(TestWorkScheduler { + tx_events: tx_events.clone(), + }); + let child_manager = ChildManager::new(default_runtime(), work_scheduler); + pick_first::reg(); + let pick_first_builder = GLOBAL_LB_REGISTRY + .get_policy(pick_first::POLICY_NAME) + .unwrap(); + + let mut policy = LeastRequestPolicy::new(child_manager, pick_first_builder); + let mut tcc = TestChannelController { tx_events }; + + let update = ResolverUpdate { + endpoints: Ok(vec![]), + ..Default::default() + }; + + let res = policy.resolver_update(update, None, &mut tcc); + assert!(res.is_err()); + } +} diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index 487d495a4..842ac8ae8 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -43,6 +43,7 @@ use crate::rt::GrpcRuntime; pub(crate) mod child_manager; pub(crate) mod graceful_switch; pub(crate) mod lazy; +pub(crate) mod least_request; pub(crate) mod pick_first; pub(crate) mod round_robin; pub(crate) mod subchannel; diff --git a/grpc/src/client/service_config.rs b/grpc/src/client/service_config.rs index dc7464920..c3987344a 100644 --- a/grpc/src/client/service_config.rs +++ b/grpc/src/client/service_config.rs @@ -35,4 +35,5 @@ pub enum LbPolicyType { #[default] PickFirst, RoundRobin, + LeastRequest, } diff --git a/grpc/src/inmemory/mod.rs b/grpc/src/inmemory/mod.rs index 24354d2d9..0d16877ca 100644 --- a/grpc/src/inmemory/mod.rs +++ b/grpc/src/inmemory/mod.rs @@ -392,9 +392,20 @@ pub struct InMemoryResolverBuilder {} impl ResolverBuilder for InMemoryResolverBuilder { fn build(&self, target: &Target, options: ResolverOptions) -> Box { let path = target.path().strip_prefix('/').unwrap_or(target.path()); - let ids: Vec = path.split(',').map(|s| s.to_string()).collect(); + let (lb_policy, rest) = if let Some(stripped) = path.strip_prefix("leastrequest/") { + ( + crate::client::service_config::LbPolicyType::LeastRequest, + stripped, + ) + } else { + ( + crate::client::service_config::LbPolicyType::RoundRobin, + path, + ) + }; + let ids: Vec = rest.split(',').map(|s| s.to_string()).collect(); options.work_scheduler.schedule_work(); - Box::new(InMemoryResolver { ids }) + Box::new(InMemoryResolver { ids, lb_policy }) } fn scheme(&self) -> &str { @@ -408,6 +419,7 @@ impl ResolverBuilder for InMemoryResolverBuilder { struct InMemoryResolver { ids: Vec, + lb_policy: crate::client::service_config::LbPolicyType, } impl Resolver for InMemoryResolver { @@ -430,9 +442,7 @@ impl Resolver for InMemoryResolver { let _ = channel_controller.update(ResolverUpdate { endpoints: Ok(endpoints), service_config: Ok(Some(ServiceConfig { - load_balancing_policy: Some( - crate::client::service_config::LbPolicyType::RoundRobin, - ), + load_balancing_policy: Some(self.lb_policy.clone()), })), ..Default::default() }); @@ -493,4 +503,96 @@ mod tests { _ => panic!("expected trailers with error, got {:?}", item), } } + + #[tokio::test] + async fn test_in_memory_least_request_load_balancing() { + reg(); // Register transport and resolver + crate::client::load_balancing::least_request::reg(); // Register least request policy + + let backend1 = InMemoryListener::new(); + let backend2 = InMemoryListener::new(); + + let b1_id = backend1.id(); + let b2_id = backend2.id(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel(10); + let (tx2, mut rx2) = tokio::sync::mpsc::channel(10); + + let b1 = backend1.clone(); + let handle1 = tokio::spawn(async move { + if let Some(call) = b1.accept().await { + tx1.send(()).await.unwrap(); + // Keep the connection/stream open until the test is done + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + drop(call); + } + }); + + let b2 = backend2.clone(); + let handle2 = tokio::spawn(async move { + if let Some(call) = b2.accept().await { + tx2.send(()).await.unwrap(); + // Keep the connection/stream open until the test is done + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + drop(call); + } + }); + + // Construct target URI using the newly supported prefix + let target = format!("inmemory:///leastrequest/{},{}", b1_id, b2_id); + + let channel = crate::client::Channel::new( + &target, + crate::credentials::LocalChannelCredentials::new_arc(), + crate::client::ChannelOptions::default(), + ); + + // Make first invoke to establish an active request on whichever backend is ready first + let (_send1, _recv1) = channel + .invoke( + crate::core::RequestHeaders::new().with_method_name("/test/method"), + crate::client::CallOptions::default(), + ) + .await; + + // Loop and retry the second invoke until the other backend is ready and picked. + // Since the first backend is kept active (1 active request), the Least Request policy + // will immediately pick Backend 2 once it becomes ready (0 active requests). + let mut b1_called = false; + let mut b2_called = false; + + for _ in 0..50 { + // Check which backends have been called so far + if rx1.try_recv().is_ok() { + b1_called = true; + } + if rx2.try_recv().is_ok() { + b2_called = true; + } + + if b1_called && b2_called { + break; + } + + // Make a short-lived invoke. If it goes to the same backend, it will be dropped immediately, + // returning its active request count back to 1. + let invoke_future = channel.invoke( + crate::core::RequestHeaders::new().with_method_name("/test/method"), + crate::client::CallOptions::default(), + ); + if let Ok((_send_tmp, _recv_tmp)) = + tokio::time::timeout(std::time::Duration::from_millis(100), invoke_future).await + { + // Successfully made call + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + + assert!(b1_called, "Backend 1 was not called"); + assert!(b2_called, "Backend 2 was not called"); + + handle1.abort(); + handle2.abort(); + } }