Skip to content

Commit 964a074

Browse files
IMDS Source Detection Logic Improvement (#5602)
* Implemented imds v1 and v2 probing. Throwing error instead of falling back to imdsv1 when determining source. * Imds probing test infrastructure + improved and fixed all broken unit tests in MangedIdentityTests.cs * Fixed + Improved all tests in ImdsV2Tests.cs * Fixed + improved ImdsTests * Implemented GitHub feedback * Implemented GitHub feedback * Implemented GitHub feedback --------- Co-authored-by: Bogdan Gavril <[email protected]>
1 parent 4dd269e commit 964a074

32 files changed

+631
-402
lines changed

src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs renamed to src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryConditions.cs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ public static bool DefaultManagedIdentity(HttpResponse response, Exception excep
2727
};
2828
}
2929

30+
/// <summary>
31+
/// Retry policy specific to Imds v1 and v2 Probe.
32+
/// Extends Imds retry policy but excludes 404 status code.
33+
/// </summary>
34+
public static bool ImdsProbe(HttpResponse response, Exception exception)
35+
{
36+
if (!Imds(response, exception))
37+
{
38+
return false;
39+
}
40+
41+
// If Imds would retry but the status code is 404, don't retry
42+
return (int)response.StatusCode is not 404;
43+
}
44+
3045
/// <summary>
3146
/// Retry policy specific to IMDS Managed Identity.
3247
/// </summary>
@@ -62,21 +77,6 @@ public static bool RegionDiscovery(HttpResponse response, Exception exception)
6277
return (int)response.StatusCode is not (404 or 408);
6378
}
6479

65-
/// <summary>
66-
/// Retry policy specific to CSR Metadata Probe.
67-
/// Extends Imds retry policy but excludes 404 status code.
68-
/// </summary>
69-
public static bool CsrMetadataProbe(HttpResponse response, Exception exception)
70-
{
71-
if (!Imds(response, exception))
72-
{
73-
return false;
74-
}
75-
76-
// If Imds would retry but the status code is 404, don't retry
77-
return (int)response.StatusCode is not 404;
78-
}
79-
8080
/// <summary>
8181
/// Retry condition for /token and /authorize endpoints
8282
/// </summary>

src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs renamed to src/client/Microsoft.Identity.Client/Http/Retry/ImdsProbeRetryPolicy.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
namespace Microsoft.Identity.Client.Http.Retry
77
{
8-
internal class CsrMetadataProbeRetryPolicy : ImdsRetryPolicy
8+
internal class ImdsProbeRetryPolicy : ImdsRetryPolicy
99
{
1010
protected override bool ShouldRetry(HttpResponse response, Exception exception)
1111
{
12-
return HttpRetryConditions.CsrMetadataProbe(response, exception);
12+
return HttpRetryConditions.ImdsProbe(response, exception);
1313
}
1414
}
1515
}

src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType)
1414
case RequestType.STS:
1515
case RequestType.ManagedIdentityDefault:
1616
return new DefaultRetryPolicy(requestType);
17+
case RequestType.ImdsProbe:
18+
return new ImdsProbeRetryPolicy();
1719
case RequestType.Imds:
1820
return new ImdsRetryPolicy();
1921
case RequestType.RegionDiscovery:
2022
return new RegionDiscoveryRetryPolicy();
21-
case RequestType.CsrMetadataProbe:
22-
return new CsrMetadataProbeRetryPolicy();
2323
default:
2424
throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type.");
2525
}

src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,21 @@
1111
using Microsoft.Identity.Client.ApiConfig.Parameters;
1212
using Microsoft.Identity.Client.Core;
1313
using Microsoft.Identity.Client.Http;
14+
using Microsoft.Identity.Client.Http.Retry;
1415
using Microsoft.Identity.Client.Internal;
16+
using Microsoft.Identity.Client.ManagedIdentity.V2;
17+
using Microsoft.Identity.Client.OAuth2;
1518

1619
namespace Microsoft.Identity.Client.ManagedIdentity
1720
{
1821
internal class ImdsManagedIdentitySource : AbstractManagedIdentity
1922
{
2023
// IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
2124
// used in unit tests as well
25+
public const string ApiVersionQueryParam = "api-version";
2226
public const string DefaultImdsBaseEndpoint= "http://169.254.169.254";
23-
private const string ImdsTokenPath = "/metadata/identity/oauth2/token";
2427
public const string ImdsApiVersion = "2018-02-01";
28+
public const string ImdsTokenPath = "/metadata/identity/oauth2/token";
2529

2630
private const string DefaultMessage = "[Managed Identity] Service request failed.";
2731

@@ -36,6 +40,11 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity
3640

3741
private static string s_cachedBaseEndpoint = null;
3842

43+
public static AbstractManagedIdentity Create(RequestContext requestContext)
44+
{
45+
return new ImdsManagedIdentitySource(requestContext);
46+
}
47+
3948
internal ImdsManagedIdentitySource(RequestContext requestContext) :
4049
base(requestContext, ManagedIdentitySource.Imds)
4150
{
@@ -51,7 +60,7 @@ protected override Task<ManagedIdentityRequest> CreateRequestAsync(string resour
5160
ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint);
5261

5362
request.Headers.Add("Metadata", "true");
54-
request.QueryParameters["api-version"] = ImdsApiVersion;
63+
request.QueryParameters[ApiVersionQueryParam] = ImdsApiVersion;
5564
request.QueryParameters["resource"] = resource;
5665

5766
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
@@ -211,5 +220,106 @@ public static Uri GetValidatedEndpoint(
211220

212221
return builder.Uri;
213222
}
223+
224+
public static string ImdsQueryParamsHelper(
225+
RequestContext requestContext,
226+
string apiVersionQueryParam,
227+
string imdsApiVersion)
228+
{
229+
var queryParams = $"{apiVersionQueryParam}={imdsApiVersion}";
230+
231+
var userAssignedIdQueryParam = GetUserAssignedIdQueryParam(
232+
requestContext.ServiceBundle.Config.ManagedIdentityId.IdType,
233+
requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId,
234+
requestContext.Logger);
235+
236+
if (userAssignedIdQueryParam != null)
237+
{
238+
queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}";
239+
}
240+
241+
return queryParams;
242+
}
243+
244+
public static async Task<bool> ProbeImdsEndpointAsync(
245+
RequestContext requestContext,
246+
ImdsVersion imdsVersion,
247+
CancellationToken cancellationToken)
248+
{
249+
string apiVersionQueryParam;
250+
string imdsApiVersion;
251+
string imdsEndpoint;
252+
string imdsStringHelper;
253+
254+
switch (imdsVersion)
255+
{
256+
case ImdsVersion.V2:
257+
#if NET462
258+
requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe.");
259+
return false;
260+
#else
261+
apiVersionQueryParam = ImdsV2ManagedIdentitySource.ApiVersionQueryParam;
262+
imdsApiVersion = ImdsV2ManagedIdentitySource.ImdsV2ApiVersion;
263+
imdsEndpoint = ImdsV2ManagedIdentitySource.CsrMetadataPath;
264+
imdsStringHelper = "IMDSv2";
265+
break;
266+
#endif
267+
case ImdsVersion.V1:
268+
apiVersionQueryParam = ApiVersionQueryParam;
269+
imdsApiVersion = ImdsApiVersion;
270+
imdsEndpoint = ImdsTokenPath;
271+
imdsStringHelper = "IMDSv1";
272+
break;
273+
274+
default:
275+
throw new ArgumentOutOfRangeException(nameof(imdsVersion), imdsVersion, null);
276+
}
277+
278+
var queryParams = ImdsQueryParamsHelper(requestContext, apiVersionQueryParam, imdsApiVersion);
279+
280+
// probe omits the "Metadata: true" header and then treats 400 Bad Request as success
281+
var headers = new Dictionary<string, string>
282+
{
283+
{ OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() }
284+
};
285+
286+
IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory;
287+
IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.ImdsProbe);
288+
289+
HttpResponse response = null;
290+
291+
try
292+
{
293+
response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync(
294+
GetValidatedEndpoint(requestContext.Logger, imdsEndpoint, queryParams),
295+
headers,
296+
body: null,
297+
method: HttpMethod.Get,
298+
logger: requestContext.Logger,
299+
doNotThrow: false,
300+
mtlsCertificate: null,
301+
validateServerCertificate: null,
302+
cancellationToken: cancellationToken,
303+
retryPolicy: retryPolicy)
304+
.ConfigureAwait(false);
305+
}
306+
catch (Exception ex)
307+
{
308+
requestContext.Logger.Info($"[Managed Identity] {imdsStringHelper} probe endpoint failure. Exception occurred while sending request to probe endpoint: {ex}");
309+
return false;
310+
}
311+
312+
// probe omits the "Metadata: true" header and then treats 400 Bad Request as success
313+
if (response.StatusCode == HttpStatusCode.BadRequest)
314+
{
315+
requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is available.");
316+
return true;
317+
}
318+
else
319+
{
320+
requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is not available. Status code: {response.StatusCode}, Body: {response.Body}");
321+
return false;
322+
}
323+
}
214324
}
215325
}

src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the MIT License.
33

44
using System;
5-
using System.Collections.Concurrent;
65
using System.IO;
76
using System.Security.Cryptography.X509Certificates;
87
using System.Threading;
@@ -41,12 +40,15 @@ internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityA
4140
AcquireTokenForManagedIdentityParameters parameters,
4241
CancellationToken cancellationToken)
4342
{
44-
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested).ConfigureAwait(false);
43+
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested, cancellationToken).ConfigureAwait(false);
4544
return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
4645
}
4746

4847
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
49-
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext, bool isMtlsPopRequested)
48+
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(
49+
RequestContext requestContext,
50+
bool isMtlsPopRequested,
51+
CancellationToken cancellationToken)
5052
{
5153
using (requestContext.Logger.LogMethodDuration())
5254
{
@@ -58,28 +60,27 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
5860
if (s_sourceName == ManagedIdentitySource.None)
5961
{
6062
// First invocation: detect and cache
61-
source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested).ConfigureAwait(false);
63+
source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested, cancellationToken).ConfigureAwait(false);
6264
}
6365
else
6466
{
6567
// Reuse cached value
6668
source = s_sourceName;
6769
}
6870

69-
// If the source has already been set to ImdsV2 (via this method,
70-
// or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) and mTLS PoP was NOT requested
71-
// In this case, we need to fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
71+
// If the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs)
72+
// and mTLS PoP was NOT requested: fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
7273
if (source == ManagedIdentitySource.ImdsV2 && !isMtlsPopRequested)
7374
{
7475
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2.");
7576
// Do NOT modify s_sourceName; keep cached ImdsV2 so future PoP
7677
// requests can leverage it.
77-
source = ManagedIdentitySource.DefaultToImds;
78+
source = ManagedIdentitySource.Imds;
7879
}
7980

8081
// If the source is determined to be ImdsV1 and mTLS PoP was requested,
8182
// throw an exception since ImdsV1 does not support mTLS PoP
82-
if (source == ManagedIdentitySource.DefaultToImds && isMtlsPopRequested)
83+
if (source == ManagedIdentitySource.Imds && isMtlsPopRequested)
8384
{
8485
throw new MsalClientException(
8586
MsalError.MtlsPopTokenNotSupportedinImdsV1,
@@ -94,7 +95,8 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
9495
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
9596
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
9697
ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext),
97-
_ => new ImdsManagedIdentitySource(requestContext)
98+
ManagedIdentitySource.Imds => ImdsManagedIdentitySource.Create(requestContext),
99+
_ => throw new MsalClientException(MsalError.ManagedIdentityAllSourcesUnavailable, MsalErrorMessage.ManagedIdentityAllSourcesUnavailable)
98100
};
99101
}
100102
}
@@ -103,39 +105,58 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
103105
// This method is perf sensitive any changes should be benchmarked.
104106
internal async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
105107
RequestContext requestContext,
106-
bool isMtlsPopRequested)
108+
bool isMtlsPopRequested,
109+
CancellationToken cancellationToken)
107110
{
108111
// First check env vars to avoid the probe if possible
109-
ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger);
110-
111-
// If a source is detected via env vars, or
112-
// a source wasn't detected (it defaulted to ImdsV1) and MtlsPop was NOT requested,
113-
// use the source.
114-
// (don't trigger the ImdsV2 probe endpoint if MtlsPop was NOT requested)
115-
if (source != ManagedIdentitySource.DefaultToImds || !isMtlsPopRequested)
112+
ManagedIdentitySource source = GetManagedIdentitySourceNoImds(requestContext.Logger);
113+
if (source != ManagedIdentitySource.None)
116114
{
117115
s_sourceName = source;
118116
return source;
119117
}
120118

121-
// Otherwise, probe IMDSv2
122-
var response = await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext, probeMode: true).ConfigureAwait(false);
123-
if (response != null)
119+
// skip the ImdsV2 probe if MtlsPop was NOT requested
120+
if (isMtlsPopRequested)
121+
{
122+
var imdsV2Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false);
123+
if (imdsV2Response)
124+
{
125+
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
126+
s_sourceName = ManagedIdentitySource.ImdsV2;
127+
return s_sourceName;
128+
}
129+
}
130+
else
131+
{
132+
requestContext.Logger.Info("[Managed Identity] Mtls Pop was not requested; skipping ImdsV2 probe.");
133+
}
134+
135+
var imdsV1Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V1, cancellationToken).ConfigureAwait(false);
136+
if (imdsV1Response)
124137
{
125-
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
126-
s_sourceName = ManagedIdentitySource.ImdsV2;
138+
requestContext.Logger.Info("[Managed Identity] ImdsV1 detected.");
139+
s_sourceName = ManagedIdentitySource.Imds;
127140
return s_sourceName;
128141
}
129142

130-
requestContext.Logger.Info("[Managed Identity] IMDSv2 probe failed. Defaulting to IMDSv1.");
131-
s_sourceName = ManagedIdentitySource.DefaultToImds;
143+
requestContext.Logger.Info($"[Managed Identity] {MsalErrorMessage.ManagedIdentityAllSourcesUnavailable}");
144+
s_sourceName = ManagedIdentitySource.None;
132145
return s_sourceName;
133146
}
134147

135-
// Detect managed identity source based on the availability of environment variables.
136-
// The result of this method is not cached because reading environment variables is cheap.
137-
// This method is perf sensitive any changes should be benchmarked.
138-
internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAdapter logger = null)
148+
/// <summary>
149+
/// Detects the managed identity source based on the availability of environment variables.
150+
/// It does not probe IMDS, but it checks for all other sources.
151+
/// This method does not cache its result, as reading environment variables is inexpensive.
152+
/// It is performance sensitive; any changes should be benchmarked.
153+
/// </summary>
154+
/// <param name="logger">Optional logger for diagnostic output.</param>
155+
/// <returns>
156+
/// The detected <see cref="ManagedIdentitySource"/> based on environment variables.
157+
/// Returns <c>ManagedIdentitySource.None</c> if no environment-based source is detected.
158+
/// </returns>
159+
internal static ManagedIdentitySource GetManagedIdentitySourceNoImds(ILoggerAdapter logger = null)
139160
{
140161
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
141162
string identityHeader = EnvironmentVariables.IdentityHeader;
@@ -177,7 +198,7 @@ internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAd
177198
}
178199
else
179200
{
180-
return ManagedIdentitySource.DefaultToImds;
201+
return ManagedIdentitySource.None;
181202
}
182203
}
183204

src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public enum ManagedIdentitySource
4848
/// Indicates that the source is defaulted to IMDS since no environment variables are set.
4949
/// This is used to detect the managed identity source.
5050
/// </summary>
51+
[Obsolete("In use only to support the now obsolete GetManagedIdentitySource API. Will be removed in a future version. Use GetManagedIdentitySourceAsync instead.")]
5152
DefaultToImds,
5253

5354
/// <summary>

0 commit comments

Comments
 (0)