Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,19 @@ public override async Task<AgentRunResponse> RunAsync(IEnumerable<ChatMessage> m
throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread));
}

// Get existing messages from the store
var invokingContext = new ChatMessageStore.InvokingContext(messages);
var storeMessages = await typedThread.MessageStore.InvokingAsync(invokingContext, cancellationToken);

// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();

// Notify the thread of the input and output messages.
await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken);
var invokedContext = new ChatMessageStore.InvokedContext(messages, storeMessages)
{
ResponseMessages = responseMessages
};
await typedThread.MessageStore.InvokedAsync(invokedContext, cancellationToken);

return new AgentRunResponse
{
Expand All @@ -68,11 +76,19 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread));
}

// Get existing messages from the store
var invokingContext = new ChatMessageStore.InvokingContext(messages);
var storeMessages = await typedThread.MessageStore.InvokingAsync(invokingContext, cancellationToken);

// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();

// Notify the thread of the input and output messages.
await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken);
var invokedContext = new ChatMessageStore.InvokedContext(messages, storeMessages)
{
ResponseMessages = responseMessages
};
await typedThread.MessageStore.InvokedAsync(invokedContext, cancellationToken);

foreach (var message in responseMessages)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@
.CreateAIAgent(new ChatClientAgentOptions
{
ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." },
AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)
AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions),
// Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy
// that removes messages produced by the TextSearchProvider before they are added to the chat history, so that
// we don't bloat chat history with all the search result messages.
ChatMessageStoreFactory = ctx => new InMemoryChatMessageStore(ctx.SerializedState, ctx.JsonSerializerOptions)
.WithAIContextProviderMessageRemoval(),
});

AgentThread thread = agent.GetNewThread();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,7 @@ public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedSto

public string? ThreadDbKey { get; private set; }

public override async Task AddMessagesAsync(IEnumerable<ChatMessage> messages, CancellationToken cancellationToken = default)
{
this.ThreadDbKey ??= Guid.NewGuid().ToString("N");

var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);

await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem()
{
Key = this.ThreadDbKey + x.MessageId,
Timestamp = DateTimeOffset.UtcNow,
ThreadId = this.ThreadDbKey,
SerializedMessage = JsonSerializer.Serialize(x),
MessageText = x.Text
}), cancellationToken);
}

public override async Task<IEnumerable<ChatMessage>> GetMessagesAsync(CancellationToken cancellationToken = default)
public override async ValueTask<IEnumerable<ChatMessage>> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);
Expand All @@ -123,6 +106,33 @@ public override async Task<IEnumerable<ChatMessage>> GetMessagesAsync(Cancellati
return messages;
}

public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
// Don't store messages if the request failed.
if (context.InvokeException is not null)
{
return;
}

this.ThreadDbKey ??= Guid.NewGuid().ToString("N");

var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);

// Add both request and response messages to the store
// Optionally messages produced by the AIContextProvider can also be persisted (not shown).
var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []);

await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem()
{
Key = this.ThreadDbKey + x.MessageId,
Timestamp = DateTimeOffset.UtcNow,
ThreadId = this.ThreadDbKey,
SerializedMessage = JsonSerializer.Serialize(x),
MessageText = x.Text
}), cancellationToken);
}

public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) =>
// We have to serialize the thread id, so that on deserialization we can retrieve the messages using the same thread id.
JsonSerializer.SerializeToElement(this.ThreadDbKey);
Expand Down
114 changes: 107 additions & 7 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ namespace Microsoft.Agents.AI;
public abstract class ChatMessageStore
{
/// <summary>
/// Asynchronously retrieves all messages from the store that should be provided as context for the next agent invocation.
/// Called at the start of agent invocation to retrieve all messages from the store that should be provided as context for the next agent invocation.
/// </summary>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>
/// A task that represents the asynchronous operation. The task result contains a collection of <see cref="ChatMessage"/>
Expand All @@ -59,20 +60,19 @@ public abstract class ChatMessageStore
/// and context management.
/// </para>
/// </remarks>
public abstract Task<IEnumerable<ChatMessage>> GetMessagesAsync(CancellationToken cancellationToken = default);
public abstract ValueTask<IEnumerable<ChatMessage>> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously adds new messages to the store.
/// Called at the end of the agent invocation to add new messages to the store.
/// </summary>
/// <param name="messages">The collection of chat messages to add to the store.</param>
/// <param name="context">Contains the invocation context including request messages, response messages, and any exception that occurred.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous add operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="messages"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// Messages should be added in the order they were generated to maintain proper chronological sequence.
/// The store is responsible for preserving message ordering and ensuring that subsequent calls to
/// <see cref="GetMessagesAsync"/> return messages in the correct chronological order.
/// <see cref="InvokingAsync"/> return messages in the correct chronological order.
/// </para>
/// <para>
/// Implementations may perform additional processing during message addition, such as:
Expand All @@ -83,8 +83,12 @@ public abstract class ChatMessageStore
/// <item><description>Updating indices or search capabilities</description></item>
/// </list>
/// </para>
/// <para>
/// This method is called regardless of whether the invocation succeeded or failed.
/// To check if the invocation was successful, inspect the <see cref="InvokedContext.InvokeException"/> property.
/// </para>
/// </remarks>
public abstract Task AddMessagesAsync(IEnumerable<ChatMessage> messages, CancellationToken cancellationToken = default);
public abstract ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default);

/// <summary>
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
Expand Down Expand Up @@ -121,4 +125,100 @@ public abstract class ChatMessageStore
/// </remarks>
public TService? GetService<TService>(object? serviceKey = null)
=> this.GetService(typeof(TService), serviceKey) is TService service ? service : default;

/// <summary>
/// Contains the context information provided to <see cref="InvokingAsync(InvokingContext, CancellationToken)"/>.
/// </summary>
/// <remarks>
/// This class provides context about the invocation before the messages are retrieved from the store,
/// including the new messages that will be used. Stores can use this information to determine what
/// messages should be retrieved for the invocation.
/// </remarks>
public sealed class InvokingContext
{
/// <summary>
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
/// </summary>
/// <param name="requestMessages">The new messages to be used by the agent for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
{
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
}

/// <summary>
/// Gets the caller provided messages that will be used by the agent for this invocation.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances representing new messages that were provided by the caller.
/// </value>
public IEnumerable<ChatMessage> RequestMessages { get; }
}

/// <summary>
/// Contains the context information provided to <see cref="InvokedAsync(InvokedContext, CancellationToken)"/>.
/// </summary>
/// <remarks>
/// This class provides context about a completed agent invocation, including both the
/// request messages that were used and the response messages that were generated. It also indicates
/// whether the invocation succeeded or failed.
/// </remarks>
public sealed class InvokedContext
{
/// <summary>
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
/// </summary>
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
/// <param name="chatMessageStoreMessages">The messages retrieved from the <see cref="ChatMessageStore"/> for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage> chatMessageStoreMessages)
{
this.RequestMessages = Throw.IfNull(requestMessages);
this.ChatMessageStoreMessages = chatMessageStoreMessages;
}

/// <summary>
/// Gets the caller provided messages that were used by the agent for this invocation.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances representing new messages that were provided by the caller.
/// This does not include any <see cref="ChatMessageStore"/> supplied messages.
/// </value>
public IEnumerable<ChatMessage> RequestMessages { get; }

/// <summary>
/// Gets the messages retrieved from the <see cref="ChatMessageStore"/> for this invocation, if any.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances that were retrieved from the <see cref="ChatMessageStore"/>,
/// and were used by the agent as part of the invocation.
/// </value>
public IEnumerable<ChatMessage> ChatMessageStoreMessages { get; }

/// <summary>
/// Gets or sets the messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances that were provided by the <see cref="AIContextProvider"/>,
/// and were used by the agent as part of the invocation.
/// </value>
public IEnumerable<ChatMessage>? AIContextProviderMessages { get; set; }

/// <summary>
/// Gets the collection of response messages generated during this invocation if the invocation succeeded.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances representing the response,
/// or <see langword="null"/> if the invocation failed or did not produce response messages.
/// </value>
public IEnumerable<ChatMessage>? ResponseMessages { get; set; }

/// <summary>
/// Gets the <see cref="Exception"/> that was thrown during the invocation, if the invocation failed.
/// </summary>
/// <value>
/// The exception that caused the invocation to fail, or <see langword="null"/> if the invocation succeeded.
/// </value>
public Exception? InvokeException { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI;

/// <summary>
/// Contains extension methods for the <see cref="ChatMessageStore"/> class.
/// </summary>
public static class ChatMessageStoreExtensions
{
/// <summary>
/// Adds message filtering to an existing store, so that messages passed to the store and messages produced by the store
/// can be filtered, updated or replaced.
/// </summary>
/// <param name="store">The store to add the message filter to.</param>
/// <param name="invokingMessagesFilter">An optional filter function to apply to messages before they are invoked. If null, no filter is applied at this
/// stage.</param>
/// <param name="invokedMessagesFilter">An optional filter function to apply to the invocation context after messages have been invoked. If null, no
/// filter is applied at this stage.</param>
/// <returns>The <see cref="ChatMessageStore"/> with filtering applied.</returns>
public static ChatMessageStore WithMessageFilters(
this ChatMessageStore store,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? invokingMessagesFilter = null,
Func<ChatMessageStore.InvokedContext, ChatMessageStore.InvokedContext>? invokedMessagesFilter = null)
{
return new ChatMessageStoreMessageFilter(
innerChatMessageStore: store,
invokingMessagesFilter: invokingMessagesFilter,
invokedMessagesFilter: invokedMessagesFilter);
}

/// <summary>
/// Decorates the provided chat message store so that it does not store messages produced by any <see cref="AIContextProvider"/>.
/// </summary>
/// <param name="store">The store to add the message filter to.</param>
/// <returns>A new <see cref="ChatMessageStore"/> instance that filters out <see cref="AIContextProvider"/> messages so they do not get stored.</returns>
public static ChatMessageStore WithAIContextProviderMessageRemoval(this ChatMessageStore store)
{
return new ChatMessageStoreMessageFilter(
innerChatMessageStore: store,
invokedMessagesFilter: (ctx) =>
{
ctx.AIContextProviderMessages = null;
return ctx;
});
}
}
Loading
Loading