Skip to content

Refactored RecoverableConnection #2670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All Rights reserved
// Licensed under the MIT license.

use super::recoverable_connection::RecoverableConnection;
use super::recoverable::RecoverableConnection;
use crate::error::{ErrorKind, EventHubsError};
use async_lock::Mutex as AsyncMutex;
use azure_core::{
Expand All @@ -16,20 +16,20 @@ use rand::{thread_rng, Rng};
use std::collections::HashMap;
use std::sync::{Arc, Mutex as SyncMutex, OnceLock, Weak};
use time::{Duration, OffsetDateTime};
use tracing::{debug, error, trace};
use tracing::{debug, trace, warn};

// The number of seconds before token expiration that we wake up to refresh the token.
const TOKEN_REFRESH_BIAS: Duration = Duration::minutes(6); // By default, we refresh tokens 6 minutes before they expire.
const TOKEN_REFRESH_JITTER_MIN: i64 = -5; // Minimum jitter in seconds
const TOKEN_REFRESH_JITTER_MAX: i64 = 5; // Maximum jitter in seconds
const TOKEN_REFRESH_JITTER_MIN: Duration = Duration::seconds(-5); // Minimum jitter in seconds
const TOKEN_REFRESH_JITTER_MAX: Duration = Duration::seconds(5); // Maximum jitter in seconds

const EVENTHUBS_AUTHORIZATION_SCOPE: &str = "https://eventhubs.azure.net/.default";

#[derive(Debug)]
struct TokenRefreshTimes {
before_expiration_refresh_time: Duration,
jitter_min: i64,
jitter_max: i64,
jitter_min: Duration,
jitter_max: Duration,
}

impl Default for TokenRefreshTimes {
Expand Down Expand Up @@ -174,7 +174,7 @@ impl Authorizer {
async fn refresh_tokens_task(self: Arc<Self>) {
let result = self.refresh_tokens().await;
if let Err(e) = result {
error!("Error refreshing tokens: {e}");
warn!(err=?e, "Error refreshing tokens: {e}");
}
debug!("Token refresher task completed.");
}
Expand Down Expand Up @@ -250,10 +250,10 @@ impl Authorizer {

debug!("Token refresh times: {token_refresh_times:?}");

let expiration_jitter = Duration::seconds(
thread_rng()
.gen_range(token_refresh_times.jitter_min..token_refresh_times.jitter_max),
);
let jitter_min = token_refresh_times.jitter_min.whole_milliseconds() as i64;
let jitter_max = token_refresh_times.jitter_max.whole_milliseconds() as i64;
let expiration_jitter =
Duration::milliseconds(thread_rng().gen_range(jitter_min..jitter_max));
debug!("Expiration jitter: {expiration_jitter}");

token_refresh_bias = token_refresh_times
Expand Down Expand Up @@ -543,8 +543,8 @@ mod tests {
authorizer
.set_token_refresh_times(TokenRefreshTimes {
before_expiration_refresh_time: Duration::seconds(10),
jitter_min: -2,
jitter_max: 2,
jitter_min: Duration::seconds(-2),
jitter_max: Duration::seconds(2),
})
.unwrap();

Expand Down Expand Up @@ -611,8 +611,8 @@ mod tests {
authorizer
.set_token_refresh_times(TokenRefreshTimes {
before_expiration_refresh_time: Duration::seconds(5),
jitter_min: -1,
jitter_max: 1,
jitter_min: Duration::milliseconds(-500),
jitter_max: Duration::milliseconds(500),
})
.unwrap();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All Rights reserved
// Licensed under the MIT license.

use super::recoverable_connection::RecoverableConnection;
use super::recoverable::RecoverableConnection;
use crate::{
error::{ErrorKind, EventHubsError},
models::{EventHubPartitionProperties, EventHubProperties},
Expand Down
4 changes: 2 additions & 2 deletions sdk/eventhubs/azure_messaging_eventhubs/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

pub(crate) mod authorizer;
pub(crate) mod management;
pub(crate) mod recoverable_connection;
pub(crate) mod recoverable;
pub mod retry;
pub(crate) mod user_agent;

// Public API
pub(crate) use management::ManagementInstance;
pub(crate) use retry::{retry_azure_operation, RetryOptions};
pub(crate) use retry::retry_azure_operation;
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All Rights reserved
// Licensed under the MIT license.

use super::RecoverableConnection;
use crate::{common::retry_azure_operation, RetryOptions};
use azure_core::{credentials::Secret, error::ErrorKind as AzureErrorKind, error::Result};
use azure_core_amqp::{
AmqpClaimsBasedSecurity, AmqpClaimsBasedSecurityApis, AmqpConnection, AmqpError, AmqpSession,
AmqpSessionApis,
};
use std::error::Error;
use std::sync::Arc;
use tracing::{debug, warn};

/// Thin wrapper around the [`AmqpClaimsBasedSecurityApis`] trait that implements the retry functionality.
///
/// A RecoverableClaimsBasedSecurity is a thin wrapper around the [`AmqpClaimsBasedSecurityApis`] trait which implements
/// the retry functionality. That allows implementations which call into the authorize_path API to not have
/// to worry about retrying the operation themselves.
pub(crate) struct RecoverableClaimsBasedSecurity {
recoverable_connection: Arc<RecoverableConnection>,
}

impl RecoverableClaimsBasedSecurity {
/// Creates a new RecoverableClaimsBasedSecurity.
///
/// # Arguments
///
/// * `recoverable_connection` - The recoverable connection to use for authorization.
pub(super) fn new(recoverable_connection: Arc<RecoverableConnection>) -> Self {
Self {
recoverable_connection,
}
}

pub(super) async fn create_claims_based_security(
connection: Arc<AmqpConnection>,
retry_options: &RetryOptions,
) -> Result<Arc<AmqpClaimsBasedSecurity>> {
retry_azure_operation(
|| async {
let session = AmqpSession::new();
session.begin(connection.as_ref(), None).await?;

let claims_based_security = Arc::new(AmqpClaimsBasedSecurity::new(session)?);

// Attach the claims_based_security client to the session.
claims_based_security.attach().await?;
Ok(claims_based_security)
},
retry_options,
Some(Self::should_retry_claims_based_security_response),
)
.await
}

fn should_retry_claims_based_security_response(e: &azure_core::Error) -> bool {
match e.kind() {
AzureErrorKind::Amqp => {
warn!(err=?e, "Amqp operation failed: {:?}", e.source());
if let Some(e) = e.source() {
debug!(err=?e, "Error: {e}");

if let Some(amqp_error) = e.downcast_ref::<Box<AmqpError>>() {
RecoverableConnection::should_retry_amqp_error(amqp_error)
} else if let Some(amqp_error) = e.downcast_ref::<AmqpError>() {
RecoverableConnection::should_retry_amqp_error(amqp_error)
} else {
debug!(err=?e, "Non AMQP error: {e}");
false
}
} else {
debug!("No source error found");
false
}
}
_ => {
debug!(err=?e, "Non AMQP error: {e}");
false
}
}
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl AmqpClaimsBasedSecurityApis for RecoverableClaimsBasedSecurity {
async fn authorize_path(
&self,
path: String,
token_type: Option<String>,
secret: &Secret,
expires_on: time::OffsetDateTime,
) -> Result<()> {
let result = retry_azure_operation(
|| {
let path = path.clone();
let token_type = token_type.clone();
let secret = secret.clone();

async move {
let claims_based_security_client =
self.recoverable_connection.ensure_amqp_cbs().await?;
claims_based_security_client
.authorize_path(path, token_type, &secret, expires_on)
.await
}
},
&self.recoverable_connection.retry_options,
Some(Self::should_retry_claims_based_security_response),
)
.await?;
Ok(result)
}

async fn attach(&self) -> azure_core::Result<()> {
unimplemented!("AmqpClaimsBasedSecurityClient does not support attach operation");
}

async fn detach(self) -> azure_core::Result<()> {
unimplemented!("AmqpClaimsBasedSecurityClient does not support detach operation");
}
}
Loading