Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ public static bool DefaultManagedIdentity(HttpResponse response, Exception excep
};
}

/// <summary>
/// Retry policy specific to Imds v1 and v2 Probe.
/// Extends Imds retry policy but excludes 404 status code.
/// </summary>
public static bool ImdsProbe(HttpResponse response, Exception exception)
{
if (!Imds(response, exception))
{
return false;
}

// If Imds would retry but the status code is 404, don't retry
return (int)response.StatusCode is not 404;
}

/// <summary>
/// Retry policy specific to IMDS Managed Identity.
/// </summary>
Expand Down Expand Up @@ -62,21 +77,6 @@ public static bool RegionDiscovery(HttpResponse response, Exception exception)
return (int)response.StatusCode is not (404 or 408);
}

/// <summary>
/// Retry policy specific to CSR Metadata Probe.
/// Extends Imds retry policy but excludes 404 status code.
/// </summary>
public static bool CsrMetadataProbe(HttpResponse response, Exception exception)
{
if (!Imds(response, exception))
{
return false;
}

// If Imds would retry but the status code is 404, don't retry
return (int)response.StatusCode is not 404;
}

/// <summary>
/// Retry condition for /token and /authorize endpoints
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

namespace Microsoft.Identity.Client.Http.Retry
{
internal class CsrMetadataProbeRetryPolicy : ImdsRetryPolicy
internal class ImdsProbeRetryPolicy : ImdsRetryPolicy
{
protected override bool ShouldRetry(HttpResponse response, Exception exception)
{
return HttpRetryConditions.CsrMetadataProbe(response, exception);
return HttpRetryConditions.ImdsProbe(response, exception);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType)
case RequestType.STS:
case RequestType.ManagedIdentityDefault:
return new DefaultRetryPolicy(requestType);
case RequestType.ImdsProbe:
return new ImdsProbeRetryPolicy();
case RequestType.Imds:
return new ImdsRetryPolicy();
case RequestType.RegionDiscovery:
return new RegionDiscoveryRetryPolicy();
case RequestType.CsrMetadataProbe:
return new CsrMetadataProbeRetryPolicy();
default:
throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,21 @@
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Http.Retry;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.ManagedIdentity.V2;
using Microsoft.Identity.Client.OAuth2;

namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class ImdsManagedIdentitySource : AbstractManagedIdentity
{
// IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
// used in unit tests as well
public const string ApiVersionQueryParam = "api-version";
public const string DefaultImdsBaseEndpoint= "http://169.254.169.254";
private const string ImdsTokenPath = "/metadata/identity/oauth2/token";
public const string ImdsApiVersion = "2018-02-01";
public const string ImdsTokenPath = "/metadata/identity/oauth2/token";

private const string DefaultMessage = "[Managed Identity] Service request failed.";

Expand All @@ -36,6 +40,11 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity

private static string s_cachedBaseEndpoint = null;

public static AbstractManagedIdentity Create(RequestContext requestContext)
{
return new ImdsManagedIdentitySource(requestContext);
}

internal ImdsManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.Imds)
{
Expand All @@ -51,7 +60,7 @@ protected override Task<ManagedIdentityRequest> CreateRequestAsync(string resour
ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint);

request.Headers.Add("Metadata", "true");
request.QueryParameters["api-version"] = ImdsApiVersion;
request.QueryParameters[ApiVersionQueryParam] = ImdsApiVersion;
request.QueryParameters["resource"] = resource;

switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
Expand Down Expand Up @@ -211,5 +220,106 @@ public static Uri GetValidatedEndpoint(

return builder.Uri;
}

public static string ImdsQueryParamsHelper(
RequestContext requestContext,
string apiVersionQueryParam,
string imdsApiVersion)
{
var queryParams = $"{apiVersionQueryParam}={imdsApiVersion}";

var userAssignedIdQueryParam = GetUserAssignedIdQueryParam(
requestContext.ServiceBundle.Config.ManagedIdentityId.IdType,
requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId,
requestContext.Logger);

if (userAssignedIdQueryParam != null)
{
queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}";
}

return queryParams;
}

public static async Task<bool> ProbeImdsEndpointAsync(
RequestContext requestContext,
ImdsVersion imdsVersion,
CancellationToken cancellationToken)
{
string apiVersionQueryParam;
string imdsApiVersion;
string imdsEndpoint;
string imdsStringHelper;

switch (imdsVersion)
{
case ImdsVersion.V2:
#if NET462
requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe.");
return false;
#else
apiVersionQueryParam = ImdsV2ManagedIdentitySource.ApiVersionQueryParam;
imdsApiVersion = ImdsV2ManagedIdentitySource.ImdsV2ApiVersion;
imdsEndpoint = ImdsV2ManagedIdentitySource.CsrMetadataPath;
imdsStringHelper = "IMDSv2";
break;
#endif
case ImdsVersion.V1:
apiVersionQueryParam = ApiVersionQueryParam;
imdsApiVersion = ImdsApiVersion;
imdsEndpoint = ImdsTokenPath;
imdsStringHelper = "IMDSv1";
break;

default:
throw new ArgumentOutOfRangeException(nameof(imdsVersion), imdsVersion, null);
}

var queryParams = ImdsQueryParamsHelper(requestContext, apiVersionQueryParam, imdsApiVersion);

// probe omits the "Metadata: true" header and then treats 400 Bad Request as success
var headers = new Dictionary<string, string>
{
{ OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() }
};

IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory;
IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.ImdsProbe);

HttpResponse response = null;

try
{
response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync(
GetValidatedEndpoint(requestContext.Logger, imdsEndpoint, queryParams),
headers,
body: null,
method: HttpMethod.Get,
logger: requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
validateServerCertificate: null,
cancellationToken: cancellationToken,
retryPolicy: retryPolicy)
.ConfigureAwait(false);
}
catch (Exception ex)
{
requestContext.Logger.Info($"[Managed Identity] {imdsStringHelper} probe endpoint failure. Exception occurred while sending request to probe endpoint: {ex}");
return false;
}

// probe omits the "Metadata: true" header and then treats 400 Bad Request as success
if (response.StatusCode == HttpStatusCode.BadRequest)
{
requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is available.");
return true;
}
else
{
requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is not available. Status code: {response.StatusCode}, Body: {response.Body}");
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.IO;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand Down Expand Up @@ -41,12 +40,15 @@ internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityA
AcquireTokenForManagedIdentityParameters parameters,
CancellationToken cancellationToken)
{
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested).ConfigureAwait(false);
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested, cancellationToken).ConfigureAwait(false);
return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
}

// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext, bool isMtlsPopRequested)
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(
RequestContext requestContext,
bool isMtlsPopRequested,
CancellationToken cancellationToken)
{
using (requestContext.Logger.LogMethodDuration())
{
Expand All @@ -58,28 +60,27 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
if (s_sourceName == ManagedIdentitySource.None)
{
// First invocation: detect and cache
source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested).ConfigureAwait(false);
source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested, cancellationToken).ConfigureAwait(false);
}
else
{
// Reuse cached value
source = s_sourceName;
}

// If the source has already been set to ImdsV2 (via this method,
// or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) and mTLS PoP was NOT requested
// In this case, we need to fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
// If the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs)
// and mTLS PoP was NOT requested: fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
if (source == ManagedIdentitySource.ImdsV2 && !isMtlsPopRequested)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2.");
// Do NOT modify s_sourceName; keep cached ImdsV2 so future PoP
// requests can leverage it.
source = ManagedIdentitySource.DefaultToImds;
source = ManagedIdentitySource.Imds;
}

// If the source is determined to be ImdsV1 and mTLS PoP was requested,
// throw an exception since ImdsV1 does not support mTLS PoP
if (source == ManagedIdentitySource.DefaultToImds && isMtlsPopRequested)
if (source == ManagedIdentitySource.Imds && isMtlsPopRequested)
{
throw new MsalClientException(
MsalError.MtlsPopTokenNotSupportedinImdsV1,
Expand All @@ -94,7 +95,8 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext),
_ => new ImdsManagedIdentitySource(requestContext)
ManagedIdentitySource.Imds => ImdsManagedIdentitySource.Create(requestContext),
_ => throw new MsalClientException(MsalError.ManagedIdentityAllSourcesUnavailable, MsalErrorMessage.ManagedIdentityAllSourcesUnavailable)
};
}
}
Expand All @@ -103,39 +105,58 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
// This method is perf sensitive any changes should be benchmarked.
internal async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
RequestContext requestContext,
bool isMtlsPopRequested)
bool isMtlsPopRequested,
CancellationToken cancellationToken)
{
// First check env vars to avoid the probe if possible
ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger);

// If a source is detected via env vars, or
// a source wasn't detected (it defaulted to ImdsV1) and MtlsPop was NOT requested,
// use the source.
// (don't trigger the ImdsV2 probe endpoint if MtlsPop was NOT requested)
if (source != ManagedIdentitySource.DefaultToImds || !isMtlsPopRequested)
ManagedIdentitySource source = GetManagedIdentitySourceNoImds(requestContext.Logger);
if (source != ManagedIdentitySource.None)
{
s_sourceName = source;
return source;
}

// Otherwise, probe IMDSv2
var response = await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext, probeMode: true).ConfigureAwait(false);
if (response != null)
// skip the ImdsV2 probe if MtlsPop was NOT requested
if (isMtlsPopRequested)
{
var imdsV2Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false);
if (imdsV2Response)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
s_sourceName = ManagedIdentitySource.ImdsV2;
return s_sourceName;
}
}
else
{
requestContext.Logger.Info("[Managed Identity] Mtls Pop was not requested; skipping ImdsV2 probe.");
}

var imdsV1Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V1, cancellationToken).ConfigureAwait(false);
if (imdsV1Response)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
s_sourceName = ManagedIdentitySource.ImdsV2;
requestContext.Logger.Info("[Managed Identity] ImdsV1 detected.");
s_sourceName = ManagedIdentitySource.Imds;
return s_sourceName;
}

requestContext.Logger.Info("[Managed Identity] IMDSv2 probe failed. Defaulting to IMDSv1.");
s_sourceName = ManagedIdentitySource.DefaultToImds;
requestContext.Logger.Info($"[Managed Identity] {MsalErrorMessage.ManagedIdentityAllSourcesUnavailable}");
s_sourceName = ManagedIdentitySource.None;
return s_sourceName;
}

// Detect managed identity source based on the availability of environment variables.
// The result of this method is not cached because reading environment variables is cheap.
// This method is perf sensitive any changes should be benchmarked.
internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAdapter logger = null)
/// <summary>
/// Detects the managed identity source based on the availability of environment variables.
/// It does not probe IMDS, but it checks for all other sources.
/// This method does not cache its result, as reading environment variables is inexpensive.
/// It is performance sensitive; any changes should be benchmarked.
/// </summary>
/// <param name="logger">Optional logger for diagnostic output.</param>
/// <returns>
/// The detected <see cref="ManagedIdentitySource"/> based on environment variables.
/// Returns <c>ManagedIdentitySource.None</c> if no environment-based source is detected.
/// </returns>
internal static ManagedIdentitySource GetManagedIdentitySourceNoImds(ILoggerAdapter logger = null)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
string identityHeader = EnvironmentVariables.IdentityHeader;
Expand Down Expand Up @@ -177,7 +198,7 @@ internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAd
}
else
{
return ManagedIdentitySource.DefaultToImds;
return ManagedIdentitySource.None;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public enum ManagedIdentitySource
/// Indicates that the source is defaulted to IMDS since no environment variables are set.
/// This is used to detect the managed identity source.
/// </summary>
[Obsolete("In use only to support the now obsolete GetManagedIdentitySource API. Will be removed in a future version. Use GetManagedIdentitySourceAsync instead.")]
DefaultToImds,

/// <summary>
Expand Down
Loading
Loading