diff --git a/libraries/Microsoft.Bot.Builder.Dialogs/SkillDialog.cs b/libraries/Microsoft.Bot.Builder.Dialogs/SkillDialog.cs index 25016a19a1..0ef35f545c 100644 --- a/libraries/Microsoft.Bot.Builder.Dialogs/SkillDialog.cs +++ b/libraries/Microsoft.Bot.Builder.Dialogs/SkillDialog.cs @@ -249,6 +249,9 @@ private async Task SendToSkillAsync(ITurnContext context, Activity act await DialogOptions.ConversationState.SaveChangesAsync(context, true, cancellationToken).ConfigureAwait(false); var skillInfo = DialogOptions.Skill; + + DialogOptions.SkillClient.AddDefaultHeaders(); + var response = await DialogOptions.SkillClient.PostActivityAsync(DialogOptions.BotId, skillInfo.AppId, skillInfo.SkillEndpoint, DialogOptions.SkillHostEndpoint, skillConversationId, activity, cancellationToken).ConfigureAwait(false); // Inspect the skill response status diff --git a/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs b/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs index ad10362487..fb05222b65 100644 --- a/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs +++ b/libraries/Microsoft.Bot.Builder/CloudAdapterBase.cs @@ -259,7 +259,7 @@ protected virtual ConnectorFactory GetStreamingConnectorFactory(Activity activit /// The method to call for the resulting bot turn. /// Cancellation token. /// A task that represents the work queued to execute. - protected async Task ProcessProactiveAsync(ClaimsIdentity claimsIdentity, Activity continuationActivity, string audience, BotCallbackHandler callback, CancellationToken cancellationToken) + protected virtual async Task ProcessProactiveAsync(ClaimsIdentity claimsIdentity, Activity continuationActivity, string audience, BotCallbackHandler callback, CancellationToken cancellationToken) { Logger.LogInformation($"ProcessProactiveAsync for Conversation Id: {continuationActivity.Conversation.Id}"); diff --git a/libraries/Microsoft.Bot.Connector/Authentication/BotFrameworkClientImpl.cs b/libraries/Microsoft.Bot.Connector/Authentication/BotFrameworkClientImpl.cs index c86bf857cb..848bf6292b 100644 --- a/libraries/Microsoft.Bot.Connector/Authentication/BotFrameworkClientImpl.cs +++ b/libraries/Microsoft.Bot.Connector/Authentication/BotFrameworkClientImpl.cs @@ -121,6 +121,11 @@ public async override Task> PostActivityAsync(string fromBo } } + public override void AddDefaultHeaders() + { + ConnectorClient.AddDefaultRequestHeaders(_httpClient); + } + protected override void Dispose(bool disposing) { if (_disposed) diff --git a/libraries/Microsoft.Bot.Connector/Bot.Builder/BotFrameworkClient.cs b/libraries/Microsoft.Bot.Connector/Bot.Builder/BotFrameworkClient.cs index 1bec6f7722..05e2851f00 100644 --- a/libraries/Microsoft.Bot.Connector/Bot.Builder/BotFrameworkClient.cs +++ b/libraries/Microsoft.Bot.Connector/Bot.Builder/BotFrameworkClient.cs @@ -45,6 +45,13 @@ public async virtual Task PostActivityAsync(string fromBotId, st /// Async task with optional invokeResponse. public abstract Task> PostActivityAsync(string fromBotId, string toBotId, Uri toUrl, Uri serviceUrl, string conversationId, Activity activity, CancellationToken cancellationToken = default); + /// + /// Allows to add default headers to the HTTP client after the creation of the instance. + /// + public virtual void AddDefaultHeaders() + { + } + /// public void Dispose() { diff --git a/libraries/Microsoft.Bot.Connector/ConnectorClientEx.cs b/libraries/Microsoft.Bot.Connector/ConnectorClientEx.cs index 40062fe64b..6b2a5275f5 100644 --- a/libraries/Microsoft.Bot.Connector/ConnectorClientEx.cs +++ b/libraries/Microsoft.Bot.Connector/ConnectorClientEx.cs @@ -250,6 +250,19 @@ public static void AddDefaultRequestHeaders(HttpClient httpClient) } } + var headersToPropagate = HeaderPropagation.HeadersToPropagate; + + if (headersToPropagate != null && headersToPropagate.Count > 0) + { + foreach (var header in headersToPropagate) + { + if (!httpClient.DefaultRequestHeaders.Contains(header.Key)) + { + httpClient.DefaultRequestHeaders.Add(header.Key, header.Value.ToArray()); + } + } + } + httpClient.DefaultRequestHeaders.ExpectContinue = false; var jsonAcceptHeader = new MediaTypeWithQualityHeaderValue("*/*"); diff --git a/libraries/Microsoft.Bot.Connector/HeaderPropagation.cs b/libraries/Microsoft.Bot.Connector/HeaderPropagation.cs new file mode 100644 index 0000000000..70c957a547 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector/HeaderPropagation.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Threading; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Bot.Connector +{ + /// + /// Class to handle header propagation from incoming request to outgoing request. + /// + public static class HeaderPropagation + { + private static readonly AsyncLocal> _requestHeaders = new (); + + private static readonly AsyncLocal> _headersToPropagate = new (); + + /// + /// Gets or sets the headers from an incoming request. + /// + /// The headers from an incoming request. + public static IDictionary RequestHeaders + { + get => _requestHeaders.Value ??= new Dictionary(StringComparer.OrdinalIgnoreCase); + set => _requestHeaders.Value = value; + } + + /// + /// Gets or sets the selected headers for propagation. + /// + /// The selected headers for propagation. + public static IDictionary HeadersToPropagate + { + get => _headersToPropagate.Value ??= new Dictionary(StringComparer.OrdinalIgnoreCase); + set => _headersToPropagate.Value = value; + } + + /// + /// Filters the request's headers to include only those relevant for propagation. + /// + /// The chosen headers to propagate. + /// The filtered headers. + public static IDictionary FilterHeaders(HeaderPropagationEntryCollection headerFilter) + { + // We propagate the X-Ms-Correlation-Id header by default. + headerFilter.Propagate("X-Ms-Correlation-Id"); + + var filteredHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (var filter in headerFilter.Entries) + { + if (RequestHeaders.TryGetValue(filter.Key, out var value)) + { + switch (filter.Action) + { + case HeaderPropagationEntryAction.Add: + break; + case HeaderPropagationEntryAction.Append: + filteredHeaders[filter.Key] = StringValues.Concat(value, filter.Value); + break; + case HeaderPropagationEntryAction.Override: + filteredHeaders.Add(filter.Key, filter.Value); + break; + case HeaderPropagationEntryAction.Propagate: + filteredHeaders.Add(filter.Key, value); + break; + } + } + else + { + switch (filter.Action) + { + case HeaderPropagationEntryAction.Add: + filteredHeaders.Add(filter.Key, filter.Value); + break; + case HeaderPropagationEntryAction.Override: + filteredHeaders.Add(filter.Key, filter.Value); + break; + } + } + } + + return filteredHeaders; + } + } +} diff --git a/libraries/Microsoft.Bot.Connector/HeaderPropagationEntry.cs b/libraries/Microsoft.Bot.Connector/HeaderPropagationEntry.cs new file mode 100644 index 0000000000..b5a5300cc2 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector/HeaderPropagationEntry.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Bot.Connector +{ + /// + /// Represents the action to perform with the header entry. + /// + public enum HeaderPropagationEntryAction + { + /// + /// Adds a new header entry to the outgoing request. + /// + [EnumMember(Value = "add")] + Add, + + /// + /// Appends a new header value to an existing key in the outgoing request. + /// + [EnumMember(Value = "append")] + Append, + + /// + /// Propagates the header entry from the incoming request to the outgoing request without modifications. + /// + [EnumMember(Value = "propagate")] + Propagate, + + /// + /// Overrides an existing header entry in the outgoing request. + /// + [EnumMember(Value = "override")] + Override + } + + /// + /// Represents a single header entry used for header propagation. + /// + public class HeaderPropagationEntry + { + /// + /// Gets or sets the key of the header entry. + /// + /// Key of the header entry. + public string Key { get; set; } = string.Empty; + + /// + /// Gets or sets the value of the header entry. + /// + /// Value of the header entry. + public StringValues Value { get; set; } = new StringValues(string.Empty); + + /// + /// Gets or sets the action of the header entry (Add, Append, Override or Propagate). + /// + /// Action of the header entry. + public HeaderPropagationEntryAction Action { get; set; } = HeaderPropagationEntryAction.Propagate; + } +} diff --git a/libraries/Microsoft.Bot.Connector/HeaderPropagationEntryCollection.cs b/libraries/Microsoft.Bot.Connector/HeaderPropagationEntryCollection.cs new file mode 100644 index 0000000000..4b66c85b17 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector/HeaderPropagationEntryCollection.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Bot.Connector +{ + /// + /// Represents a collection of the header entries configured to be propagated to outgoing requests. + /// + public class HeaderPropagationEntryCollection + { + private readonly Dictionary _entries = new (StringComparer.OrdinalIgnoreCase); + + /// + /// Gets the collection of header entries to be propagated to outgoing requests. + /// + /// The collection of header entries. + public List Entries => _entries.Values.ToList(); + + /// + /// Attempts to add a new header entry to the collection. + /// + /// + /// If the key already exists, it will be ignored. + /// + /// The key of the element to add. + /// The value to add for the specified key. + public void Add(string key, StringValues value) + { + _entries[key] = new HeaderPropagationEntry + { + Key = key, + Value = value, + Action = HeaderPropagationEntryAction.Add + }; + } + + /// + /// Appends a new header value to an existing key. + /// + /// + /// If the key does not exist, it will be ignored. + /// + /// The key of the element to append the value. + /// The value to append for the specified key. + public void Append(string key, StringValues value) + { + StringValues newValue; + + if (_entries.TryGetValue(key, out var entry)) + { + // If the key already exists, append the new value to the existing one. + newValue = StringValues.Concat(entry.Value, value); + } + + _entries[key] = new HeaderPropagationEntry + { + Key = key, + Value = !StringValues.IsNullOrEmpty(newValue) ? newValue : value, + Action = HeaderPropagationEntryAction.Append + }; + } + + /// + /// Propagates the incoming request header value to outgoing requests without modifications. + /// + /// + /// If the key does not exist, it will be ignored. + /// + /// The key of the element to propagate. + public void Propagate(string key) + { + _entries[key] = new HeaderPropagationEntry + { + Key = key, + Action = HeaderPropagationEntryAction.Propagate + }; + } + + /// + /// Overrides the header value of an existing key. + /// + /// + /// If the key does not exist, it will add it. + /// + /// The key of the element to override. + /// The value to override in the specified key. + public void Override(string key, StringValues value) + { + _entries[key] = new HeaderPropagationEntry + { + Key = key, + Value = value, + Action = HeaderPropagationEntryAction.Override + }; + } + } +} diff --git a/libraries/Microsoft.Bot.Connector/Teams/TeamsHeaderPropagation.cs b/libraries/Microsoft.Bot.Connector/Teams/TeamsHeaderPropagation.cs new file mode 100644 index 0000000000..55a907c7c3 --- /dev/null +++ b/libraries/Microsoft.Bot.Connector/Teams/TeamsHeaderPropagation.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Bot.Connector.Teams +{ + /// + /// Instantiate this class to set the headers to propagate from incoming request to outgoing request. + /// + public static class TeamsHeaderPropagation + { + /// + /// Returns the headers to propagate from incoming request to outgoing request. + /// + /// The collection of headers to propagate. + public static HeaderPropagationEntryCollection GetHeadersToPropagate() + { + // Propagate headers to the outgoing request by adding them to the HeaderPropagationEntryCollection. + // For example: + var headersToPropagate = new HeaderPropagationEntryCollection(); + + //headersToPropagate.Propagate("X-Ms-Teams-Id"); + //headersToPropagate.Add("X-Ms-Teams-Custom", new StringValues("Custom-Value")); + //headersToPropagate.Append("X-Ms-Teams-Channel", new StringValues("-SubChannel-Id")); + //headersToPropagate.Override("X-Ms-Other", new StringValues("new-value")); + + return headersToPropagate; + } + } +} diff --git a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs index a94875571b..817edb4f54 100644 --- a/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs +++ b/libraries/integration/Microsoft.Bot.Builder.Integration.AspNet.Core/CloudAdapter.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; using System.Net; using System.Net.Http; using System.Net.WebSockets; @@ -14,10 +16,12 @@ using Microsoft.Bot.Connector; using Microsoft.Bot.Connector.Authentication; using Microsoft.Bot.Connector.Streaming.Application; +using Microsoft.Bot.Connector.Teams; using Microsoft.Bot.Schema; using Microsoft.Bot.Streaming; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; namespace Microsoft.Bot.Builder.Integration.AspNet.Core { @@ -27,6 +31,7 @@ namespace Microsoft.Bot.Builder.Integration.AspNet.Core public class CloudAdapter : CloudAdapterBase, IBotFrameworkHttpAdapter { private readonly ConcurrentDictionary _streamingConnections = new ConcurrentDictionary(); + private readonly IStorage _storage; /// /// Initializes a new instance of the class. (Public cloud. No auth. For testing.) @@ -41,11 +46,14 @@ public CloudAdapter() /// /// The this adapter should use. /// The implementation this adapter should use. + /// The implementation this adapter should use for header propagation. public CloudAdapter( BotFrameworkAuthentication botFrameworkAuthentication, - ILogger logger = null) + ILogger logger = null, + IStorage storage = null) : base(botFrameworkAuthentication, logger) { + _storage = storage ?? new MemoryStorage(); } /// @@ -54,12 +62,15 @@ public CloudAdapter( /// The instance. /// The this adapter should use. /// The implementation this adapter should use. + /// The implementation this adapter should use for header propagation. public CloudAdapter( IConfiguration configuration, IHttpClientFactory httpClientFactory = null, - ILogger logger = null) + ILogger logger = null, + IStorage storage = null) : this(new ConfigurationBotFrameworkAuthentication(configuration, httpClientFactory: httpClientFactory, logger: logger), logger) { + _storage = storage ?? new MemoryStorage(); } /// @@ -98,6 +109,14 @@ public async Task ProcessAsync(HttpRequest httpRequest, HttpResponse httpRespons return; } + var filteredHeaders = GetPropagationHeaders(httpRequest, activity); + + if (activity.Conversation?.Id != null) + { + // Store headers to be retrieved in case of proactive messages. + StoreFilteredHeaders(filteredHeaders, activity.Conversation.Id, cancellationToken); + } + // Grab the auth header from the inbound http request var authHeader = httpRequest.Headers["Authorization"]; @@ -106,6 +125,12 @@ public async Task ProcessAsync(HttpRequest httpRequest, HttpResponse httpRespons // Write the response, potentially serializing the InvokeResponse await HttpHelper.WriteResponseAsync(httpResponse, invokeResponse).ConfigureAwait(false); + + if (activity.Type == ActivityTypes.EndOfConversation && activity.Conversation?.Id != null) + { + // Delete stored headers to avoid memory bloat. + DeleteStoredHeaders(activity.Conversation.Id, cancellationToken); + } } else { @@ -206,6 +231,83 @@ protected virtual StreamingConnection CreateWebSocketConnection(WebSocket socket return new WebSocketStreamingConnection(socket, logger); } + /// + protected override async Task ProcessProactiveAsync(ClaimsIdentity claimsIdentity, Activity continuationActivity, string audience, BotCallbackHandler callback, CancellationToken cancellationToken) + { + // Retrieve the headers from IStorage + if (continuationActivity.Conversation?.Id != null) + { + var storageKey = $"headers-{continuationActivity.Conversation.Id}"; + var readAttempts = 3; + var delay = TimeSpan.FromMilliseconds(100); + + while (readAttempts > 0) + { + try + { + var storedData = await _storage.ReadAsync([storageKey], cancellationToken).ConfigureAwait(false); + + if (storedData.TryGetValue(storageKey, out var headersObject) && headersObject is Dictionary serializedHeaders) + { + var headers = serializedHeaders.ToDictionary( + kvp => kvp.Key, + kvp => new StringValues(kvp.Value)); + + HeaderPropagation.HeadersToPropagate = headers; + break; + } + else + { + // No headers found, retry. + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); + readAttempts--; + } + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to read headers from storage."); + readAttempts--; + } + } + } + + await base.ProcessProactiveAsync(claimsIdentity, continuationActivity, audience, callback, cancellationToken); + } + + /// + /// Get the headers to propagate from the the incoming request. + /// + /// The incoming request to get the headers from. + /// The activity contained in the request. + /// The headers to be propagated to outgoing requests. + private IDictionary GetPropagationHeaders(HttpRequest httpRequest, IActivity activity) + { + // Read the headers from the request. + var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (var header in httpRequest.Headers) + { + headers[header.Key] = header.Value; + } + + HeaderPropagation.RequestHeaders = headers; + + // Look for the selected headers to propagate. + var headersCollection = new HeaderPropagationEntryCollection(); + + // TODO: If a channel implements a static class to configure header propagation, add it to this block. + //if (activity.ChannelId == Channels.Msteams) + //{ + // headersCollection = TeamsHeaderPropagation.GetHeadersToPropagate(); + //} + + var filteredHeaders = HeaderPropagation.FilterHeaders(headersCollection); + + HeaderPropagation.HeadersToPropagate = filteredHeaders; + + return filteredHeaders; + } + private async Task ConnectAsync(HttpRequest httpRequest, IBot bot, CancellationToken cancellationToken) { Logger.LogInformation($"Received request for web socket connect."); @@ -236,6 +338,51 @@ private async Task ConnectAsync(HttpRequest httpRequest, IBot bot, CancellationT } } + private void StoreFilteredHeaders(IDictionary filteredHeaders, string conversationId, CancellationToken cancellationToken) + { + var serializedHeaders = filteredHeaders.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value.ToArray()); + + var storageKey = $"headers-{conversationId}"; + var storageData = new Dictionary + { + { storageKey, serializedHeaders } + }; + + // fire and forget the write operation to avoid blocking the request. + _ = Task.Run( + async () => + { + try + { + await _storage.WriteAsync(storageData, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to write headers to storage."); + } + }, cancellationToken); + } + + private void DeleteStoredHeaders(string conversationId, CancellationToken cancellationToken) + { + // fire and forget the delete operation to avoid blocking the request. + _ = Task.Run( + async () => + { + try + { + var storageKey = $"headers-{conversationId}"; + await _storage.DeleteAsync([storageKey], cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to delete headers after EndOfConversation"); + } + }, cancellationToken); + } + private class StreamingActivityProcessor : IStreamingActivityProcessor, IDisposable { private readonly AuthenticateRequestResult _authenticateRequestResult; diff --git a/tests/Microsoft.Bot.Connector.Tests/HeaderPropagationTests.cs b/tests/Microsoft.Bot.Connector.Tests/HeaderPropagationTests.cs new file mode 100644 index 0000000000..5f2207734a --- /dev/null +++ b/tests/Microsoft.Bot.Connector.Tests/HeaderPropagationTests.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.Bot.Connector.Tests +{ + [Collection("Non-Parallel Collection")] // Ensure this test runs in a single-threaded context to avoid issues with static dictionary. + public class HeaderPropagationTests + { + public HeaderPropagationTests() + { + HeaderPropagation.HeadersToPropagate = new Dictionary(); + } + + [Fact] + public void HeaderPropagation_ShouldFilterHeaders() + { + // Arrange + HeaderPropagation.RequestHeaders = new Dictionary + { + { "x-custom-header-1", new StringValues("Value-1") }, + { "x-custom-header-2", new StringValues("Value-2") }, + { "x-custom-header-3", new StringValues("Value-3") } + }; + + var headersToPropagate = new HeaderPropagationEntryCollection(); + + headersToPropagate.Add("x-custom-header", "custom-value"); + headersToPropagate.Propagate("x-custom-header-1"); + headersToPropagate.Override("x-custom-header-2", "new-value"); + headersToPropagate.Append("x-custom-header-3", "extra-value"); + + // Act + var filteredHeaders = HeaderPropagation.FilterHeaders(headersToPropagate); + + // Assert + Assert.Equal(4, filteredHeaders.Count); + Assert.Equal("custom-value", filteredHeaders["x-custom-header"]); + Assert.Equal("Value-1", filteredHeaders["x-custom-header-1"]); + Assert.Equal("new-value", filteredHeaders["x-custom-header-2"]); + Assert.Equal("Value-3,extra-value", filteredHeaders["x-custom-header-3"]); + } + + [Fact] + public void HeaderPropagation_ShouldAppendMultipleValues() + { + // Arrange + HeaderPropagation.RequestHeaders = new Dictionary + { + { "User-Agent", new StringValues("Value-1") } + }; + + var headersToPropagate = new HeaderPropagationEntryCollection(); + + headersToPropagate.Append("User-Agent", "extra-value-1"); + headersToPropagate.Append("User-Agent", "extra-value-2"); + + // Act + var filteredHeaders = HeaderPropagation.FilterHeaders(headersToPropagate); + + // Assert + Assert.Single(filteredHeaders); + Assert.Equal("Value-1,extra-value-1,extra-value-2", filteredHeaders["User-Agent"]); + } + + [Fact] + public void HeaderPropagation_MultipleAdd_ShouldKeepLastValue() + { + // Arrange + HeaderPropagation.RequestHeaders = new Dictionary(); + + var headersToPropagate = new HeaderPropagationEntryCollection(); + + headersToPropagate.Add("x-custom-header-1", "value-1"); + headersToPropagate.Add("x-custom-header-1", "value-2"); + + // Act + var filteredHeaders = HeaderPropagation.FilterHeaders(headersToPropagate); + + // Assert + Assert.Single(filteredHeaders); + Assert.Equal("value-2", filteredHeaders["x-custom-header-1"]); + } + + [Fact] + public void HeaderPropagation_MultipleOverride_ShouldKeepLastValue() + { + // Arrange + HeaderPropagation.RequestHeaders = new Dictionary + { + { "x-custom-header-1", new StringValues("Value-1") } + }; + + var headersToPropagate = new HeaderPropagationEntryCollection(); + headersToPropagate.Override("x-custom-header-1", "new-value-1"); + headersToPropagate.Override("x-custom-header-1", "new-value-2"); + + // Act + var filteredHeaders = HeaderPropagation.FilterHeaders(headersToPropagate); + + // Assert + Assert.Single(filteredHeaders); + Assert.Equal("new-value-2", filteredHeaders["x-custom-header-1"]); + } + } + + [CollectionDefinition("Non-Parallel Collection", DisableParallelization = true)] +#pragma warning disable SA1402 // File may only contain a single type + public class NonParallelCollectionDefinition + { + } +#pragma warning restore SA1402 // File may only contain a single type +} diff --git a/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs b/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs index 5e6798c83a..6d1a788bd9 100644 --- a/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs +++ b/tests/integration/Microsoft.Bot.Builder.Integration.AspNet.Core.Tests/CloudAdapterTests.cs @@ -41,13 +41,13 @@ public class CloudAdapterTests public async Task BasicMessageActivity() { // Arrange - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateMessageActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var httpResponseMock = new Mock(); @@ -66,13 +66,13 @@ public async Task BasicMessageActivity() public async Task InvokeActivity() { // Arrange - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateInvokeActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var response = new MemoryStream(); var httpResponseMock = new Mock(); @@ -353,7 +353,7 @@ public async Task CanContinueConversationOverWebSocketAsync() var nullUrlProcessRequest = adapter.ProcessAsync(nullUrlHttpRequest.Object, nullUrlHttpResponse.Object, bot.Object, CancellationToken.None); var processRequest = adapter.ProcessAsync(httpRequest.Object, httpResponse.Object, bot.Object, CancellationToken.None); - var validContinuation = adapter.ContinueConversationAsync( + await adapter.ContinueConversationAsync( authResult.ClaimsIdentity, validActivity, (turn, cancellationToken) => @@ -368,8 +368,8 @@ public async Task CanContinueConversationOverWebSocketAsync() }, CancellationToken.None); - var invalidContinuation = adapter.ContinueConversationAsync( - authResult.ClaimsIdentity, invalidActivity, (turn, cancellationToken) => Task.CompletedTask, CancellationToken.None); + await Assert.ThrowsAsync(() => adapter.ContinueConversationAsync( + authResult.ClaimsIdentity, invalidActivity, (turn, cancellationToken) => Task.CompletedTask, CancellationToken.None)); continueConversationWaiter.Set(); await nullUrlProcessRequest; @@ -378,24 +378,19 @@ public async Task CanContinueConversationOverWebSocketAsync() // Assert Assert.True(processRequest.IsCompletedSuccessfully); Assert.True(verifiedValidContinuation); - Assert.True(validContinuation.IsCompletedSuccessfully); - Assert.Null(validContinuation.Exception); - Assert.True(invalidContinuation.IsFaulted); - Assert.NotEmpty(invalidContinuation.Exception.InnerExceptions); - Assert.True(invalidContinuation.Exception.InnerExceptions[0] is ApplicationException); } [Fact] public async Task MessageActivityWithHttpClient() { // Arrange - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateMessageActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var httpResponseMock = new Mock(); @@ -474,13 +469,13 @@ public async Task BadRequest() public async Task InjectCloudEnvironment() { // Arrange - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateMessageActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var httpResponseMock = new Mock(); @@ -527,13 +522,13 @@ public async Task CloudAdapterProvidesUserTokenClient() string relatesToActivityId = "relatesToActivityId"; string connectionName = "connectionName"; - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateMessageActivityStream(userId, channelId, conversationId, recipientId, relatesToActivityId)); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var httpResponseMock = new Mock(); @@ -609,14 +604,13 @@ public async Task CloudAdapterConnectorFactory() // this is just a basic test to verify the wire-up of a ConnectorFactory in the CloudAdapter // Arrange - - var headerDictionaryMock = new Mock(); - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns(null); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateMessageActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", StringValues.Empty } + }); var httpResponseMock = new Mock(); @@ -809,7 +803,6 @@ public async Task CloudAdapterCreateConversation() public async Task ExpiredTokenShouldThrowUnauthorizedAccessException() { // Arrange - var headerDictionaryMock = new Mock(); // Expired token with removed AppID // This token will be validated against real endpoint https://login.microsoftonline.com/common/discovery/v2.0/keys @@ -826,12 +819,13 @@ public async Task ExpiredTokenShouldThrowUnauthorizedAccessException() // - delete the app var token = "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiJodHRwczovL2FwaS5ib3RmcmFtZXdvcmsuY29tIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZDZkNDk0MjAtZjM5Yi00ZGY3LWExZGMtZDU5YTkzNTg3MWRiLyIsImlhdCI6MTY5Mjg3MDMwMiwibmJmIjoxNjkyODcwMzAyLCJleHAiOjE2OTI5NTcwMDIsImFpbyI6IkUyRmdZUGhhdFZ6czVydGFFYTlWbDN2ZnIyQ2JBZ0E9IiwiYXBwaWQiOiIxNWYwMTZmZS00ODhjLTQwZTktOWNiZS00Yjk0OGY5OGUyMmMiLCJhcHBpZGFjciI6IjEiLCJpZHAiOiJodHRwczovL3N0cy53aW5kb3dzLm5ldC9kNmQ0OTQyMC1mMzliLTRkZjctYTFkYy1kNTlhOTM1ODcxZGIvIiwicmgiOiIwLkFXNEFJSlRVMXB2ejkwMmgzTldhazFoeDIwSXpMWTBwejFsSmxYY09EcS05RnJ4dUFBQS4iLCJ0aWQiOiJkNmQ0OTQyMC1mMzliLTRkZjctYTFkYy1kNTlhOTM1ODcxZGIiLCJ1dGkiOiJkenVwa1dWd2FVT2x1RldkbnlvLUFBIiwidmVyIjoiMS4wIn0.sbQH997Q2GDKiiYd6l5MIz_XNfXypJd6zLY9xjtvEgXMBB0x0Vu3fv9W0nM57_ZipQiZDTZuSQA5BE30KBBwU-ZVqQ7MgiTkmE9eF6Ngie_5HwSr9xMK3EiDghHiOP9pIj3oEwGOSyjR5L9n-7tLSdUbKVyV14nS8OQtoPd1LZfoZI3e7tVu3vx8Lx3KzudanXX8Vz7RKaYndj3RyRi4wEN5hV9ab40d7fQsUzygFd5n_PXC2rs0OhjZJzjCOTC0VLQEn1KwiTkSH1E-OSzkrMltn1sbhD2tv_H-4rqQd51vAEJ7esC76qQjz_pfDRLs6T2jvJyhd5MZrN_MT0TqlA"; - headerDictionaryMock.Setup(h => h[It.Is(v => v == "Authorization")]).Returns((_) => token); - var httpRequestMock = new Mock(); httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Post); httpRequestMock.Setup(r => r.Body).Returns(CreateInvokeActivityStream()); - httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object); + httpRequestMock.Setup(r => r.Headers).Returns(new HeaderDictionary + { + { "Authorization", token } + }); var response = new MemoryStream(); var httpResponseMock = new Mock().SetupAllProperties();