diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 8f1039251d..66e3ea2a52 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -28,10 +28,10 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new CustomAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new CustomAgentThread(serializedThread, jsonSerializerOptions); public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs index 8986734972..98d3a27245 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs @@ -21,53 +21,200 @@ // Replace this with a vector store implementation of your choice if you want to persist the chat history to disk. VectorStore vectorStore = new InMemoryVectorStore(); -// Create the agent -AIAgent agent = new AzureOpenAIClient( - new Uri(endpoint), - new AzureCliCredential()) - .GetChatClient(deploymentName) - .CreateAIAgent(new ChatClientAgentOptions - { - Instructions = "You are good at telling jokes.", - Name = "Joker", - ChatMessageStoreFactory = ctx => +// Execute various samples showing how to use a custom ChatMessageStore with an agent. +await CustomChatMessageStore_UsingFactory_Async(); +await CustomChatMessageStore_UsingFactoryAndExistingExternalId_Async(); +await CustomChatMessageStore_PerThread_Async(); +await CustomChatMessageStore_PerRun_Async(); + +// Here we can see how to create a custom ChatMessageStore using a factory method +// provided to the agent via the ChatMessageStoreFactory option. +// This allows us to use a custom chat message store, where the consumer of the agent +// doesn't need to know anything about the storage mechanism used. +async Task CustomChatMessageStore_UsingFactory_Async() +{ + Console.WriteLine("\n--- With Factory ---\n"); + + // Create the agent + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions { - // Create a new chat message store for this agent that stores the messages in a vector store. - // Each thread must get its own copy of the VectorChatMessageStore, since the store - // also contains the id that the thread is stored under. - return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions); - } - }); + Instructions = "You are good at telling jokes.", + Name = "Joker", + ChatMessageStoreFactory = ctx => + { + // Create a new chat message store for this agent that stores the messages in a vector store. + // Each thread must get its own copy of the VectorChatMessageStore, since the store + // also contains the id that the thread is stored under. + return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions, ctx.Features); + } + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // Serialize the thread state, so it can be stored for later use. + // Since the chat history is stored in the vector store, the serialized thread + // only contains the guid that the messages are stored under in the vector store. + JsonElement serializedThread = thread.Serialize(); + + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); + + // The serialized thread can now be saved to a database, file, or any other storage mechanism + // and loaded again later. + + // Deserialize the thread state after loading from storage. + AgentThread resumedThread = agent.DeserializeThread(serializedThread); + + // Run the agent with the thread that stores conversation history in the vector store a second time. + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); +} + +// Here we can see how to create a custom ChatMessageStore using a factory method +// provided to the agent via the ChatMessageStoreFactory option. +// It also shows how we can pass a custom storage id at runtime to the message store using +// the VectorChatMessageStoreThreadDbKeyFeature. +// Note that not all agents or chat message stores may support this feature. +async Task CustomChatMessageStore_UsingFactoryAndExistingExternalId_Async() +{ + Console.WriteLine("\n--- With Factory and Existing External ID ---\n"); + + // Create the agent + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker", + ChatMessageStoreFactory = ctx => + { + // Create a new chat message store for this agent that stores the messages in a vector store. + // Each thread must get its own copy of the VectorChatMessageStore, since the store + // also contains the id that the thread is stored under. + return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions, ctx.Features); + } + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // We can access the VectorChatMessageStore via the thread's GetService method if we need to read the key under which threads are stored. + var messageStoreFromFactory = thread.GetService()!; + Console.WriteLine($"\nThread is stored in vector store under key: {messageStoreFromFactory.ThreadDbKey}"); + + // It's possible to create a new thread that uses the same chat message store id by providing + // the VectorChatMessageStoreThreadDbKeyFeature in the feature collection when creating the new thread. + AgentThread resumedThread = agent.GetNewThread( + new AgentFeatureCollection().WithFeature(new VectorChatMessageStoreThreadDbKeyFeature(messageStoreFromFactory.ThreadDbKey!))); -// Start a new thread for the agent conversation. -AgentThread thread = agent.GetNewThread(); + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); +} -// Run the agent with the thread that stores conversation history in the vector store. -Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); +// Here we can see how to create a custom ChatMessageStore and pass it to the thread +// when creating a new thread. +async Task CustomChatMessageStore_PerThread_Async() +{ + Console.WriteLine("\n--- Per Thread ---\n"); -// Serialize the thread state, so it can be stored for later use. -// Since the chat history is stored in the vector store, the serialized thread -// only contains the guid that the messages are stored under in the vector store. -JsonElement serializedThread = thread.Serialize(); + // We can also create an agent without a factory that provides a ChatMessageStore. + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker" + }); -Console.WriteLine("\n--- Serialized thread ---\n"); -Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); + // Instead of using a factory on the agent to create the ChatMessageStore, we can + // create a VectorChatMessageStore ourselves and register it in a feature collection. + // We can then pass the feature collection when creating a new thread. + // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. + VectorChatMessageStore perThreadMessageStore = new(vectorStore, "chat-history-1"); + AgentThread thread = agent.GetNewThread(new AgentFeatureCollection().WithFeature(perThreadMessageStore)); -// The serialized thread can now be saved to a database, file, or any other storage mechanism -// and loaded again later. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); -// Deserialize the thread state after loading from storage. -AgentThread resumedThread = agent.DeserializeThread(serializedThread); + // When serializing this thread, we'll see that it has the id from the message store stored in its state. + JsonElement serializedThread = thread.Serialize(); -// Run the agent with the thread that stores conversation history in the vector store a second time. -Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); +} -// We can access the VectorChatMessageStore via the thread's GetService method if we need to read the key under which threads are stored. -var messageStore = resumedThread.GetService()!; -Console.WriteLine($"\nThread is stored in vector store under key: {messageStore.ThreadDbKey}"); +// Here we can see how to create a custom ChatMessageStore for a single run using the Features option +// passed when we run the agent. +// Note that if the agent doesn't support a chat message store, it would be ignored. +async Task CustomChatMessageStore_PerRun_Async() +{ + Console.WriteLine("\n--- Per Run ---\n"); + + // We can also create an agent without a factory that provides a ChatMessageStore. + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker" + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Instead of using a factory on the agent to create the ChatMessageStore, we can + // create a VectorChatMessageStore ourselves and register it in a feature collection. + // We can then pass the feature collection to the agent when running it by using the Features option. + // The message store would only be used for the run that it's passed to. + // If the agent doesn't support a message store, it would be ignored. + // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. + VectorChatMessageStore perRunMessageStore = new(vectorStore, "chat-history-1"); + Console.WriteLine(await agent.RunAsync( + "Tell me a joke about a pirate.", + thread, + options: new AgentRunOptions() + { + Features = new AgentFeatureCollection().WithFeature(perRunMessageStore) + })); + + // When serializing this thread, we'll see that it has no messagestore state, since the messagestore was not attached to the thread, + // but just provided for the single run. Note that, depending on the circumstances, the thread may still contain other state, e.g. Memories, + // if an AIContextProvider is attached which adds memory to an agent. + JsonElement serializedThread = thread.Serialize(); + + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); +} namespace SampleApp { + /// + /// A feature that allows providing the thread database key for the . + /// + internal sealed class VectorChatMessageStoreThreadDbKeyFeature(string threadDbKey) + { + public string ThreadDbKey { get; } = threadDbKey; + } + /// /// A sample implementation of that stores chat messages in a vector store. /// @@ -75,29 +222,36 @@ internal sealed class VectorChatMessageStore : ChatMessageStore { private readonly VectorStore _vectorStore; - public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null) + public VectorChatMessageStore(VectorStore vectorStore, string threadDbKey) + { + this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); + this.ThreadDbKey = threadDbKey ?? throw new ArgumentNullException(nameof(threadDbKey)); + } + + public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? features = null) { this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); - if (serializedStoreState.ValueKind is JsonValueKind.String) - { - // Here we can deserialize the thread id so that we can access the same messages as before the suspension. - this.ThreadDbKey = serializedStoreState.Deserialize(); - } + // Here we can deserialize the thread id so that we can access the same messages as before the suspension, or if + // a user provided a ConversationIdAgentFeature in the features collection, we can use that + // or finally we can generate one ourselves. + this.ThreadDbKey = serializedStoreState.ValueKind is JsonValueKind.String + ? serializedStoreState.Deserialize() + : features?.TryGet(out var threadDbKeyFeature) is true + ? threadDbKeyFeature.ThreadDbKey + : Guid.NewGuid().ToString("N"); } - public string? ThreadDbKey { get; private set; } + public string? ThreadDbKey { get; } public override async Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) { - this.ThreadDbKey ??= Guid.NewGuid().ToString("N"); - var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem() { - Key = this.ThreadDbKey + x.MessageId, + Key = this.ThreadDbKey + (string.IsNullOrWhiteSpace(x.MessageId) ? Guid.NewGuid().ToString("N") : x.MessageId), Timestamp = DateTimeOffset.UtcNow, ThreadId = this.ThreadDbKey, SerializedMessage = JsonSerializer.Serialize(x), diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index e4491970ad..d325b2a375 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -55,8 +55,13 @@ public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, str } /// - public sealed override AgentThread GetNewThread() - => new A2AAgentThread(); + public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) + => new A2AAgentThread() + { + ContextId = featureCollection?.TryGet(out var conversationIdFeature) is true + ? conversationIdFeature.ConversationId + : null + }; /// /// Get a new instance using an existing context id, to continue that conversation. @@ -67,7 +72,7 @@ public AgentThread GetNewThread(string contextId) => new A2AAgentThread() { ContextId = contextId }; /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new A2AAgentThread(serializedThread, jsonSerializerOptions); /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs index eba6f84687..2921b88724 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs @@ -108,6 +108,7 @@ public abstract class AIAgent /// /// Creates a new conversation thread that is compatible with this agent. /// + /// An optional feature collection to override or provide additional context or capabilities to the thread where the thread supports these features. /// A new instance ready for use with this agent. /// /// @@ -121,13 +122,14 @@ public abstract class AIAgent /// may be deferred until first use to optimize performance. /// /// - public abstract AgentThread GetNewThread(); + public abstract AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null); /// /// Deserializes an agent thread from its JSON serialized representation. /// /// A containing the serialized thread state. /// Optional settings to customize the deserialization process. + /// An optional feature collection to override or provide additional context or capabilities to the thread where the thread supports these features. /// A restored instance with the state from . /// The is not in the expected format. /// The serialized data is invalid or cannot be deserialized. @@ -136,7 +138,7 @@ public abstract class AIAgent /// allowing conversations to resume across application restarts or be migrated between /// different agent instances. /// - public abstract AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null); + public abstract AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null); /// /// Run the agent with no message assuming that all required instructions are already provided to the agent or on the thread. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs index 9cd6d51680..3976c9c765 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs @@ -34,6 +34,7 @@ public AgentRunOptions(AgentRunOptions options) this.ContinuationToken = options.ContinuationToken; this.AllowBackgroundResponses = options.AllowBackgroundResponses; this.AdditionalProperties = options.AdditionalProperties?.Clone(); + this.Features = options.Features; } /// @@ -90,4 +91,9 @@ public AgentRunOptions(AgentRunOptions options) /// preserving implementation-specific details or extending the options with custom data. /// public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware for this run. + /// + public IAgentFeatureCollection? Features { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs index 4794457f41..bfb52e021b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs @@ -26,8 +26,8 @@ namespace Microsoft.Agents.AI; /// Chat history reduction, e.g. where messages needs to be summarized or truncated to reduce the size. /// /// An is always constructed by an so that the -/// can attach any necessary behaviors to the . See the -/// and methods for more information. +/// can attach any necessary behaviors to the . See the +/// and methods for more information. /// /// /// Because of these behaviors, an may not be reusable across different agents, since each agent @@ -37,13 +37,13 @@ namespace Microsoft.Agents.AI; /// To support conversations that may need to survive application restarts or separate service requests, an can be serialized /// and deserialized, so that it can be saved in a persistent store. /// The provides the method to serialize the thread to a -/// and the method +/// and the method /// can be used to deserialize the thread. /// /// /// -/// -/// +/// +/// public abstract class AgentThread { /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs index 353c82c996..a542a841a3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs @@ -74,11 +74,11 @@ protected DelegatingAIAgent(AIAgent innerAgent) } /// - public override AgentThread GetNewThread() => this.InnerAgent.GetNewThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => this.InnerAgent.GetNewThread(featureCollection); /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => this.InnerAgent.DeserializeThread(serializedThread, jsonSerializerOptions); + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) + => this.InnerAgent.DeserializeThread(serializedThread, jsonSerializerOptions, featureCollection); /// public override Task RunAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs new file mode 100644 index 0000000000..df157f454c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +#pragma warning disable CA1043 // Use Integral Or String Argument For Indexers + +/// +/// Default implementation for . +/// +[DebuggerDisplay("Count = {GetCount()}")] +[DebuggerTypeProxy(typeof(FeatureCollectionDebugView))] +public class AgentFeatureCollection : IAgentFeatureCollection +{ + private readonly IAgentFeatureCollection? _innerCollection; + private Dictionary? _features; + private volatile int _containerRevision; + + /// + /// Initializes a new instance of . + /// + public AgentFeatureCollection() + { + } + + /// + /// Initializes a new instance of with the specified initial capacity. + /// + /// The initial number of elements that the collection can contain. + /// is less than 0 + public AgentFeatureCollection(int initialCapacity) + { + Throw.IfLessThan(initialCapacity, 0); + this._features = new(initialCapacity); + } + + /// + /// Initializes a new instance of with the specified inner collection. + /// + /// The inner collection. + /// + /// + /// When providing an inner collection, and if a feature is not found in this collection, + /// an attempt will be made to retrieve it from the inner collection as a fallback. + /// + /// + /// The method will only remove features from this collection + /// and not from the inner collection. When removing a feature from this collection, and + /// it exists in the inner collection, it will still be retrievable from the inner collection. + /// + /// + public AgentFeatureCollection(IAgentFeatureCollection innerCollection) + { + this._innerCollection = Throw.IfNull(innerCollection); + } + + /// + public int Revision + { + get { return this._containerRevision + (this._innerCollection?.Revision ?? 0); } + } + + /// + public bool IsReadOnly { get { return false; } } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + /// + public IEnumerator> GetEnumerator() + { + if (this._features is not { Count: > 0 }) + { + IEnumerable> e = ((IEnumerable>?)this._innerCollection) ?? []; + return e.GetEnumerator(); + } + + if (this._innerCollection is null) + { + return this._features.GetEnumerator(); + } + + if (this._innerCollection is AgentFeatureCollection innerCollection && innerCollection._features is not { Count: > 0 }) + { + return this._features.GetEnumerator(); + } + + return YieldAll(); + + IEnumerator> YieldAll() + { + HashSet set = []; + + foreach (var entry in this._features) + { + set.Add(entry.Key); + yield return entry; + } + + foreach (var entry in this._innerCollection.Where(x => !set.Contains(x.Key))) + { + yield return entry; + } + } + } + + /// + public bool TryGet([MaybeNullWhen(false)] out TFeature feature) + where TFeature : notnull + { + if (this.TryGet(typeof(TFeature), out var obj)) + { + feature = (TFeature)obj; + return true; + } + + feature = default; + return false; + } + + /// + public bool TryGet(Type type, [MaybeNullWhen(false)] out object feature) + { + if (this._features?.TryGetValue(type, out var obj) is true) + { + feature = obj; + return true; + } + + if (this._innerCollection?.TryGet(type, out var defaultFeature) is true) + { + feature = defaultFeature; + return true; + } + + feature = default; + return false; + } + + /// + public void Set(TFeature instance) + where TFeature : notnull + { + Throw.IfNull(instance); + + this._features ??= new(); + this._features[typeof(TFeature)] = instance; + this._containerRevision++; + } + + /// + public void Remove() + where TFeature : notnull + => this.Remove(typeof(TFeature)); + + /// + public void Remove(Type type) + { + if (this._features?.Remove(type) is true) + { + this._containerRevision++; + } + } + + // Used by the debugger. Count over enumerable is required to get the correct value. + private int GetCount() => this.Count(); + + private sealed class FeatureCollectionDebugView(AgentFeatureCollection features) + { + private readonly AgentFeatureCollection _features = features; + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public DictionaryItemDebugView[] Items => this._features.Select(pair => new DictionaryItemDebugView(pair)).ToArray(); + } + + /// + /// Defines a key/value pair for displaying an item of a dictionary by a debugger. + /// + [DebuggerDisplay("{Value}", Name = "[{Key}]")] + internal readonly struct DictionaryItemDebugView + { + public DictionaryItemDebugView(TKey key, TValue value) + { + this.Key = key; + this.Value = value; + } + + public DictionaryItemDebugView(KeyValuePair keyValue) + { + this.Key = keyValue.Key; + this.Value = keyValue.Value; + } + + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public TKey Key { get; } + + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public TValue Value { get; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs new file mode 100644 index 0000000000..95641858b7 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Agents.AI; + +/// +/// Extension methods for . +/// +public static class AgentFeatureCollectionExtensions +{ + /// + /// Adds the specified feature to the collection and returns the collection. + /// + /// The feature key. + /// The feature collection to add the new feature to. + /// The feature to add to the collection. + /// The updated collection. + public static IAgentFeatureCollection WithFeature(this IAgentFeatureCollection features, TFeature feature) + where TFeature : notnull + { + features.Set(feature); + return features; + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs new file mode 100644 index 0000000000..2cd267197f --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// An agent feature that allows providing a conversation identifier. +/// +/// +/// This feature allows a user to provide a specific identifier for chat history when stored in the underlying AI service. +/// +public class ConversationIdAgentFeature +{ + /// + /// Initializes a new instance of the class with the specified thread + /// identifier. + /// + /// The unique identifier of the thread required by the underlying AI service. Cannot be or empty. + public ConversationIdAgentFeature(string conversationId) + { + this.ConversationId = Throw.IfNullOrWhitespace(conversationId); + } + + /// + /// Gets the conversation identifier. + /// + public string ConversationId { get; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs new file mode 100644 index 0000000000..dca17dc668 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Agents.AI; + +#pragma warning disable CA1043 // Use Integral Or String Argument For Indexers +#pragma warning disable CA1716 // Identifiers should not match keywords + +/// +/// Represents a collection of Agent features. +/// +public interface IAgentFeatureCollection : IEnumerable> +{ + /// + /// Indicates if the collection can be modified. + /// + bool IsReadOnly { get; } + + /// + /// Incremented for each modification and can be used to verify cached results. + /// + int Revision { get; } + + /// + /// Attempts to retrieve a feature of the specified type. + /// + /// The type of the feature to retrieve. + /// When this method returns, contains the feature of type if found; otherwise, the + /// default value for the type. + /// + /// if the feature of type was successfully retrieved; + /// otherwise, . + /// + bool TryGet([MaybeNullWhen(false)] out TFeature feature) + where TFeature : notnull; + + /// + /// Attempts to retrieve a feature of the specified type. + /// + /// The type of the feature to get. + /// When this method returns, contains the feature of type if found; otherwise, the + /// default value for the type. + /// + /// if the feature of type was successfully retrieved; + /// otherwise, . + /// + bool TryGet(Type type, [MaybeNullWhen(false)] out object feature); + + /// + /// Remove a feature from the collection. + /// + /// The feature key. + void Remove() + where TFeature : notnull; + + /// + /// Remove a feature from the collection. + /// + /// The type of the feature to remove. + void Remove(Type type); + + /// + /// Sets the given feature in the collection. + /// + /// The feature key. + /// The feature value. + void Set(TFeature instance) + where TFeature : notnull; +} diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 6ca2f38d3d..c689984537 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -42,8 +42,13 @@ public CopilotStudioAgent(CopilotClient client, ILoggerFactory? loggerFactory = } /// - public sealed override AgentThread GetNewThread() - => new CopilotStudioAgentThread(); + public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) + => new CopilotStudioAgentThread() + { + ConversationId = featureCollection?.TryGet(out var conversationIdFeature) is true + ? conversationIdFeature.ConversationId + : null + }; /// /// Get a new instance using an existing conversation id, to continue that conversation. @@ -54,7 +59,7 @@ public AgentThread GetNewThread(string conversationId) => new CopilotStudioAgentThread() { ConversationId = conversationId }; /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new CopilotStudioAgentThread(serializedThread, jsonSerializerOptions); /// diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs index 1a117aff14..fc8ca78682 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs @@ -33,21 +33,17 @@ internal DurableAIAgent(TaskOrchestrationContext context, string agentName) /// Creates a new agent thread for this agent using a random session ID. /// /// A new agent thread. - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { AgentSessionId sessionId = this._context.NewAgentSessionId(this._agentName); return new DurableAgentThread(sessionId); } - /// - /// Deserializes an agent thread from JSON. - /// - /// The serialized thread data. - /// Optional JSON serializer options. - /// The deserialized agent thread. + /// public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) { return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs index 58f9598a7e..0078266896 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs @@ -13,12 +13,13 @@ internal class DurableAIAgentProxy(string name, IDurableAgentClient agentClient) public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) { return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new DurableAgentThread(AgentSessionId.WithRandomKey(this.Name!)); } diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs index fd2a1950e9..0fb0ee13f3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs @@ -30,15 +30,15 @@ public PurviewAgent(AIAgent innerAgent, PurviewWrapper purviewWrapper) } /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { - return this._innerAgent.DeserializeThread(serializedThread, jsonSerializerOptions); + return this._innerAgent.DeserializeThread(serializedThread, jsonSerializerOptions, featureCollection); } /// - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { - return this._innerAgent.GetNewThread(); + return this._innerAgent.GetNewThread(featureCollection); } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 98dc5903bf..b1767b6ea4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -61,9 +61,9 @@ private async ValueTask ValidateWorkflowAsync() protocol.ThrowIfNotChatProtocol(); } - public override AgentThread GetNewThread() => new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new WorkflowThread(this._workflow, serializedThread, this._executionEnvironment, this._checkpointManager, jsonSerializerOptions); private async ValueTask UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 30665fecf3..5ba6346570 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -270,7 +270,7 @@ public override async IAsyncEnumerable RunStreamingAsync this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), options, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); @@ -286,11 +286,20 @@ public override async IAsyncEnumerable RunStreamingAsync : this.ChatClient.GetService(serviceType, serviceKey)); /// - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new ChatClientAgentThread { - MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }), - AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) + ConversationId = featureCollection?.TryGet(out var conversationIdAgentFeature) is true + ? conversationIdAgentFeature.ConversationId + : null, + MessageStore = + featureCollection?.TryGet(out var chatMessageStoreFeature) is true + ? chatMessageStoreFeature + : this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }), + AIContextProvider = + featureCollection?.TryGet(out var aIContextProviderFeature) is true + ? aIContextProviderFeature + : this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }) }; /// @@ -346,15 +355,21 @@ public AgentThread GetNewThread(ChatMessageStore chatMessageStore) }; /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { - Func? chatMessageStoreFactory = this._agentOptions?.ChatMessageStoreFactory is null ? - null : - (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }); - - Func? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? - null : - (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }); + Func? chatMessageStoreFactory = + featureCollection?.TryGet(out var chatMessageStoreFeature) is true + ? (jse, jso) => chatMessageStoreFeature + : this._agentOptions?.ChatMessageStoreFactory is not null + ? (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) + : null; + + Func? aiContextProviderFactory = + featureCollection?.TryGet(out var aiContextProviderFeature) is true + ? (jse, jso) => aiContextProviderFeature + : this._agentOptions?.AIContextProviderFactory is not null + ? (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) + : null; return new ChatClientAgentThread( serializedThread, @@ -413,7 +428,7 @@ private async Task RunCoreAsync(out var chatMessageStoreFeature) is true) + { + messageStore = chatMessageStoreFeature; + } + // Add any existing messages from the thread to the messages to be sent to the chat client. - if (typedThread.MessageStore is not null) + if (messageStore is not null) { - inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); + inputMessagesForChatClient.AddRange(await messageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); } // If we have an AIContextProvider, we should get context from it, and update our @@ -717,10 +741,17 @@ private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, } } - private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread thread, IEnumerable newMessages, CancellationToken cancellationToken) + private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread thread, IEnumerable newMessages, AgentRunOptions? runOptions, CancellationToken cancellationToken) { var messageStore = thread.MessageStore; + // If the caller provided an override message store via run options, we should use that instead of the message store + // on the thread. + if (runOptions?.Features?.TryGet(out var chatMessageStoreFeature) is true) + { + messageStore = chatMessageStoreFeature; + } + // Only notify the message store if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (messageStore is not null) diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index f83e6912d5..2c0041ba7d 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -128,6 +128,11 @@ public class AIContextProviderFactoryContext /// Gets or sets the JSON serialization options to use when deserializing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware. + /// + public IAgentFeatureCollection? Features { get; set; } } /// @@ -145,5 +150,10 @@ public class ChatMessageStoreFactoryContext /// Gets or sets the JSON serialization options to use when deserializing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware. + /// + public IAgentFeatureCollection? Features { get; set; } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs index 05c7f5ba08..ccff3e5f43 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs @@ -73,6 +73,24 @@ public void Constructor_WithDefaultParameters_UsesBaseProperties() Assert.Equal(agent.Id, agent.DisplayName); } + [Fact] + public void GetNewThread_WithStringFeature_UsesItForContextId() + { + // Arrange + var contextIdFeature = new ConversationIdAgentFeature("feature-context-id"); + var agentWithFeature = new A2AAgent(this._a2aClient); + + // Act + var features = new AgentFeatureCollection(); + features.Set(contextIdFeature); + var thread = agentWithFeature.GetNewThread(features); + + // Assert + Assert.IsType(thread); + var a2aThread = (A2AAgentThread)thread; + Assert.Equal(contextIdFeature.ConversationId, a2aThread.ContextId); + } + [Fact] public async Task RunAsync_AllowsNonUserRoleMessagesAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs index 5111a97ad1..0f265466a3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs @@ -344,10 +344,10 @@ public abstract class TestAgentThread : AgentThread; private sealed class MockAgent : AIAgent { - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override Task RunAsync( diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs new file mode 100644 index 0000000000..9d3a3e7c66 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains unit tests for the class. +/// +public class AgentFeatureCollectionTests +{ + [Fact] + public void Feature_RoundTrips() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + // Act. + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out var actualThing)); + + // Assert. + Assert.Same(actualThing, thing); + Assert.Equal(1, interfaces.Revision); + } + + [Fact] + public void RemoveOfT_Removes() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out _)); + + // Act. + interfaces.Remove(); + + // Assert. + Assert.False(interfaces.TryGet(out _)); + Assert.Equal(2, interfaces.Revision); + } + + [Fact] + public void Remove_Removes() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out _)); + + // Act. + interfaces.Remove(typeof(IThing)); + + // Assert. + Assert.False(interfaces.TryGet(out _)); + Assert.Equal(2, interfaces.Revision); + } + + [Fact] + public void TryGetMissingFeature_ReturnsFalse() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + + // Act & Assert. + Assert.False(interfaces.TryGet(out var actualThing)); + Assert.Null(actualThing); + } + + [Fact] + public void Set_Null_Throws() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + + // Act & Assert. + Assert.Throws(() => interfaces.Set(null!)); + } + + [Fact] + public void IsReadOnly_DefaultsToFalse() + { + // Arrange. + var interfaces = new AgentFeatureCollection(); + + // Act & Assert. + Assert.False(interfaces.IsReadOnly); + } + + [Fact] + public void TryGetOfT_FallsBackToInnerCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var thing = new Thing(); + inner.Set(thing); + var outer = new AgentFeatureCollection(inner); + + // Act & Assert. + Assert.True(outer.TryGet(out var actualThing)); + Assert.Same(actualThing, thing); + } + + [Fact] + public void TryGetOfT_OverridesInnerWithOuterCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act & Assert. + Assert.True(outer.TryGet(out var actualThing)); + Assert.Same(outerThing, actualThing); + } + + [Fact] + public void TryGet_FallsBackToInnerCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var thing = new Thing(); + inner.Set(thing); + var outer = new AgentFeatureCollection(inner); + + // Act & Assert. + Assert.True(outer.TryGet(typeof(IThing), out var actualThing)); + Assert.Same(actualThing, thing); + } + + [Fact] + public void TryGet_OverridesInnerWithOuterCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act & Assert. + Assert.True(outer.TryGet(typeof(IThing), out var actualThing)); + Assert.Same(outerThing, actualThing); + } + + [Fact] + public void Enumerate_OverridesInnerWithOuterCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act. + var items = outer.ToList(); + + // Assert. + Assert.Single(items); + Assert.Same(outerThing, items.First().Value as IThing); + } + + private interface IThing + { + string Hello(); + } + + private sealed class Thing : IThing + { + public string Hello() + { + return "World"; + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs index 7460ea4623..ecaa9502d7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs @@ -23,7 +23,8 @@ public void CloningConstructorCopiesProperties() { ["key1"] = "value1", ["key2"] = 42 - } + }, + Features = new AgentFeatureCollection() }; // Act @@ -37,6 +38,7 @@ public void CloningConstructorCopiesProperties() Assert.NotSame(options.AdditionalProperties, clone.AdditionalProperties); Assert.Equal("value1", clone.AdditionalProperties["key1"]); Assert.Equal(42, clone.AdditionalProperties["key2"]); + Assert.Same(options.Features, clone.Features); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs index 4dca99a77c..b6a72110fb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -34,7 +35,12 @@ public DelegatingAIAgentTests() this._innerAgentMock.Setup(x => x.Id).Returns("test-agent-id"); this._innerAgentMock.Setup(x => x.Name).Returns("Test Agent"); this._innerAgentMock.Setup(x => x.Description).Returns("Test Description"); - this._innerAgentMock.Setup(x => x.GetNewThread()).Returns(this._testThread); + this._innerAgentMock.Setup(x => x.GetNewThread(It.IsAny())).Returns(this._testThread); + this._innerAgentMock.Setup(x => x.DeserializeThread( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns(this._testThread); this._innerAgentMock .Setup(x => x.RunAsync( @@ -135,11 +141,29 @@ public void Description_DelegatesToInnerAgent() public void GetNewThread_DelegatesToInnerAgent() { // Act - var thread = this._delegatingAgent.GetNewThread(); + var featureCollection = new AgentFeatureCollection(); + var thread = this._delegatingAgent.GetNewThread(featureCollection); // Assert Assert.Same(this._testThread, thread); - this._innerAgentMock.Verify(x => x.GetNewThread(), Times.Once); + this._innerAgentMock.Verify(x => x.GetNewThread(featureCollection), Times.Once); + } + + /// + /// Verify that DeserializeThread delegates to inner agent. + /// + [Fact] + public void DeserializeThread_DelegatesToInnerAgent() + { + // Act + var featureCollection = new AgentFeatureCollection(); + var jsonElement = new JsonElement(); + var jso = new JsonSerializerOptions(); + var thread = this._delegatingAgent.DeserializeThread(jsonElement, jso, featureCollection); + + // Assert + Assert.Same(this._testThread, thread); + this._innerAgentMock.Verify(x => x.DeserializeThread(jsonElement, jso, featureCollection), Times.Once); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index 5bc4e8afad..32603341a1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -286,12 +286,12 @@ public FakeChatClientAgent() public override string? Description { get; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(); } - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } @@ -360,12 +360,12 @@ public FakeMultiMessageAgent() public override string? Description { get; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(); } - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index c96f2d92d0..607f94bf20 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -417,9 +417,9 @@ stateObj is JsonElement state && await Task.CompletedTask; } - public override AgentThread GetNewThread() => new FakeInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new FakeInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index 78a3048747..6d24d35d60 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -425,9 +425,9 @@ private sealed class MultiResponseAgent : AIAgent public override string? Description => "Agent that produces multiple text chunks"; - public override AgentThread GetNewThread() => new TestInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -514,9 +514,9 @@ private sealed class TestAgent : AIAgent public override string? Description => "Test agent"; - public override AgentThread GetNewThread() => new TestInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs index b0ad7ec0fe..9af6bb4a31 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs @@ -11,11 +11,12 @@ internal sealed class TestAgent(string name, string description) : AIAgent public override string? Description => description; - public override AgentThread GetNewThread() => new DummyAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new DummyAgentThread(); public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) => new DummyAgentThread(); + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) => new DummyAgentThread(); public override Task RunAsync( IEnumerable messages, diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs index f2b2bcfd6a..5af498f809 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs @@ -324,10 +324,10 @@ public TestAgent(string? name, string? description, Exception exceptionToThrow) this._exceptionToThrow = exceptionToThrow; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override string? Name { get; } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 920f9f82ee..fb13aed02b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -426,10 +426,10 @@ public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsyn } /// - /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the default InMemoryChatMessageStore when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -438,12 +438,9 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", - ChatMessageStoreFactory = mockFactory.Object }); // Act @@ -455,14 +452,13 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat Assert.Equal(2, messageStore.Count); Assert.Equal("test", messageStore[0].Text); Assert.Equal("response", messageStore[1].Text); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// - /// Verify that RunAsync uses the default InMemoryChatMessageStore when the chat client returns no conversation id. + /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -471,9 +467,16 @@ public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationI It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + Mock mockChatMessageStore = new(); + + Mock> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", + ChatMessageStoreFactory = mockFactory.Object }); // Act @@ -481,17 +484,16 @@ public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationI await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert - var messageStore = Assert.IsType(thread!.MessageStore); - Assert.Equal(2, messageStore.Count); - Assert.Equal("test", messageStore[0].Text); - Assert.Equal("response", messageStore[1].Text); + Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// - /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the ChatMessageStore provided via run params when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesChatMessageStoreWhenProvidedViaFeaturesAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -503,23 +505,22 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati Mock mockChatMessageStore = new(); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", - ChatMessageStoreFactory = mockFactory.Object }); + AgentFeatureCollection features = new(); + features.Set(mockChatMessageStore.Object); + // Act ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; - await agent.RunAsync([new(ChatRole.User, "test")], thread); + await agent.RunAsync([new(ChatRole.User, "test")], thread, options: new AgentRunOptions() { Features = features }); // Assert Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.GetMessagesAsync(It.IsAny()), Times.Once); mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs index 1fd9a71b98..7c3abdcbe7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs @@ -76,4 +76,68 @@ public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() var typedThread = (ChatClientAgentThread)thread; Assert.Same(mockMessageStore.Object, typedThread.MessageStore); } + + [Fact] + public void DeserializeThread_UsesChatMessageStore_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + Assert.Fail("ChatMessageStoreFactory should not have been called."); + return null!; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + var thread = agent.DeserializeThread(json, null, agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void DeserializeThread_UsesAIContextProvider_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + Assert.Fail("AIContextProviderFactory should not have been called."); + return null!; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockContextProvider.Object); + var thread = agent.DeserializeThread(json, null, agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs index 43e0bef8bc..10af9bd9a5 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Extensions.AI; using Moq; @@ -97,4 +98,79 @@ public void GetNewThread_UsesConversationId_FromTypedOverload() var typedThread = (ChatClientAgentThread)thread; Assert.Equal(TestConversationId, typedThread.ConversationId); } + + [Fact] + public void GetNewThread_UsesConversationId_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var testConversationId = new ConversationIdAgentFeature("test_conversation_id"); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(testConversationId); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Equal(testConversationId.ConversationId, typedThread.ConversationId); + } + + [Fact] + public void GetNewThread_UsesChatMessageStore_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesAIContextProvider_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockContextProvider.Object); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void GetNewThread_Throws_IfBothConversationIdAndMessageStoreAreSet() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var testConversationId = new ConversationIdAgentFeature("test_conversation_id"); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act & Assert + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + agentFeatures.Set(testConversationId); + + var exception = Assert.Throws(() => agent.GetNewThread(agentFeatures)); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs index fb00973c78..65689c9d05 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs @@ -24,10 +24,10 @@ internal sealed class TestAIAgent : AIAgent public override string? Description => this.DescriptionFunc?.Invoke() ?? base.Description; - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => this.DeserializeThreadFunc(serializedThread, jsonSerializerOptions); - public override AgentThread GetNewThread() => + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => this.GetNewThreadFunc(); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index 0437fc7695..8dd7b438ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -135,10 +135,10 @@ private class DoubleEchoAgent(string name) : AIAgent { public override string Name => name; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new DoubleEchoAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new DoubleEchoAgentThread(); public override Task RunAsync( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs index e134f10aa7..00de718eb9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs @@ -144,10 +144,12 @@ public SimpleTestAgent(string name) public override string Name { get; } - public override AgentThread GetNewThread() => new SimpleTestAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new SimpleTestAgentThread(); - public override AgentThread DeserializeThread(System.Text.Json.JsonElement serializedThread, - System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null) => new SimpleTestAgentThread(); + public override AgentThread DeserializeThread( + System.Text.Json.JsonElement serializedThread, + System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) => new SimpleTestAgentThread(); public override Task RunAsync( IEnumerable messages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs index 9cf460e658..768ddbda73 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs @@ -24,10 +24,10 @@ private sealed class TestExecutor() : Executor("TestExecutor") private sealed class TestAgent : AIAgent { - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index a0e57006ed..2633e4bdf7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -60,10 +60,10 @@ internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent public override string Id => id; public override string? Name => id; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new HelloAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new HelloAgentThread(); public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs index b93d7862d5..977dbd4ad7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs @@ -51,10 +51,10 @@ static ChatMessage ToMessage(string text) return result; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestAgentThread(); public static TestAIAgent FromStrings(params string[] messages) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index a77fc8a495..039e01a319 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -16,12 +16,12 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre public override string Id => id ?? base.Id; public override string? Name => name ?? base.Name; - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return serializedThread.Deserialize(jsonSerializerOptions) ?? this.GetNewThread(); } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new EchoAgentThread(); }