diff --git a/src/Common/Experimentals.cs b/src/Common/Experimentals.cs index c81ef981e..b12f11a4b 100644 --- a/src/Common/Experimentals.cs +++ b/src/Common/Experimentals.cs @@ -24,4 +24,7 @@ internal static class Experimentals // public const string Tasks_DiagnosticId = "MCP5001"; // public const string Tasks_Message = "The Tasks feature is experimental within specification version 2025-11-25 and is subject to change. See SEP-1686 for more information."; // public const string Tasks_Url = "https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1686"; + + public const string UseMcpClient_DiagnosticId = "MCP5002"; + public const string UseMcpClient_Message = "The UseMcpClient middleware for integrating hosted MCP servers with IChatClient is experimental and subject to change."; } diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs new file mode 100644 index 000000000..82f2cd45d --- /dev/null +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -0,0 +1,285 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + +namespace ModelContextProtocol.Client; + +/// +/// Extension methods for adding MCP client support to chat clients. +/// +public static class McpChatClientBuilderExtensions +{ + /// + /// Adds a chat client to the chat client pipeline that creates an for each + /// in and augments it with the tools from MCP servers as instances. + /// + /// The to configure. + /// The to use, or to create a new instance. + /// The to use, or to resolve from services. + /// An optional callback to configure the for each . + /// The for method chaining. + /// + /// + /// When a HostedMcpServerTool is encountered in the tools collection, the client + /// connects to the MCP server, retrieves available tools, and expands them into callable AI functions. + /// Connections are cached by server address to avoid redundant connections. + /// + /// + /// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers. + /// + /// + [Experimental(Experimentals.UseMcpClient_DiagnosticId)] + public static ChatClientBuilder UseMcpClient( + this ChatClientBuilder builder, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null, + Action? configureTransportOptions = null) + { + return builder.Use((innerClient, services) => + { + loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!; + var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory, configureTransportOptions); + return chatClient; + }); + } + + private sealed class McpChatClient : DelegatingChatClient + { + private readonly ILoggerFactory? _loggerFactory; + private readonly ILogger _logger; + private readonly HttpClient _httpClient; + private readonly bool _ownsHttpClient; + private readonly McpClientTasksLruCache _lruCache; + private readonly Action? _configureTransportOptions; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + /// An optional to use when connecting to MCP servers. If not provided, a new instance will be created. + /// An to use for logging information about function invocation. + /// An optional callback to configure the for each . + public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null, Action? configureTransportOptions = null) + : base(innerClient) + { + _loggerFactory = loggerFactory; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _httpClient = httpClient ?? new HttpClient(); + _ownsHttpClient = httpClient is null; + _lruCache = new McpClientTasksLruCache(capacity: 20); + _configureTransportOptions = configureTransportOptions; + } + + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + if (options?.Tools is { Count: > 0 }) + { + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false); + options = options.Clone(); + options.Tools = downstreamTools; + } + + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + } + + public override async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (options?.Tools is { Count: > 0 }) + { + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false); + options = options.Clone(); + options.Tools = downstreamTools; + } + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + private async Task> BuildDownstreamAIToolsAsync(IList chatOptionsTools) + { + List downstreamTools = []; + foreach (var chatOptionsTool in chatOptionsTools) + { + if (chatOptionsTool is not HostedMcpServerTool hostedMcpTool) + { + // For other tools, we want to keep them in the list of tools. + downstreamTools.Add(chatOptionsTool); + continue; + } + + if (!Uri.TryCreate(hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || + (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) + { + throw new InvalidOperationException( + $"Invalid http(s) address: '{hostedMcpTool.ServerAddress}'. MCP server address must be an absolute http(s) URL."); + } + + // Get MCP client and its tools from cache (both are fetched together on first access). + var (_, mcpTools) = await GetClientAndToolsAsync(hostedMcpTool, parsedAddress).ConfigureAwait(false); + + // Add the listed functions to our list of tools we'll pass to the inner client. + foreach (var mcpTool in mcpTools) + { + if (hostedMcpTool.AllowedTools is not null && !hostedMcpTool.AllowedTools.Contains(mcpTool.Name)) + { + if (_logger.IsEnabled(LogLevel.Information)) + { + _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpTool.Name); + } + continue; + } + + var wrappedFunction = new McpRetriableAIFunction(mcpTool, hostedMcpTool, parsedAddress, this); + + switch (hostedMcpTool.ApprovalMode) + { + case HostedMcpServerToolNeverRequireApprovalMode: + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpTool.Name) is true: + downstreamTools.Add(wrappedFunction); + break; + + default: + // Default to always require approval if no specific mode is set. + downstreamTools.Add(new ApprovalRequiredAIFunction(wrappedFunction)); + break; + } + } + } + + return downstreamTools; + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + if (_ownsHttpClient) + { + _httpClient?.Dispose(); + } + + _lruCache.Dispose(); + } + + base.Dispose(disposing); + } + + internal async Task<(McpClient Client, IList Tools)> GetClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri) + { + // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token. + // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently. + Task<(McpClient, IList Tools)> task = _lruCache.GetOrAdd( + hostedMcpTool.ServerAddress, + static (_, state) => state.self.CreateMcpClientAndToolsAsync(state.hostedMcpTool, state.serverAddressUri, CancellationToken.None), + (self: this, hostedMcpTool, serverAddressUri)); + + try + { + return await task.ConfigureAwait(false); + } + catch + { + bool result = RemoveMcpClientFromCache(hostedMcpTool.ServerAddress, out var removedTask); + Debug.Assert(result && removedTask!.Status != TaskStatus.RanToCompletion); + throw; + } + } + + private async Task<(McpClient Client, IList Tools)> CreateMcpClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, CancellationToken cancellationToken) + { + var transportOptions = new HttpClientTransportOptions + { + Endpoint = serverAddressUri, + Name = hostedMcpTool.ServerName, + AdditionalHeaders = hostedMcpTool.AuthorizationToken is not null + // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. + ? new Dictionary() { { "Authorization", $"Bearer {hostedMcpTool.AuthorizationToken}" } } + : null, + }; + + _configureTransportOptions?.Invoke(new DummyHostedMcpServerTool(hostedMcpTool.ServerName, serverAddressUri), transportOptions); + + var transport = new HttpClientTransport(transportOptions, _httpClient, _loggerFactory); + var client = await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); + try + { + var tools = await client.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + return (client, tools); + } + catch + { + try + { + await client.DisposeAsync().ConfigureAwait(false); + } + catch { } // allow the original exception to propagate + + throw; + } + } + + internal bool RemoveMcpClientFromCache(string key, out Task<(McpClient Client, IList Tools)>? removedTask) + => _lruCache.TryRemove(key, out removedTask); + + /// + /// A temporary instance passed to the configureTransportOptions callback. + /// This prevents the callback from modifying the original tool instance. + /// + private sealed class DummyHostedMcpServerTool(string serverName, Uri serverAddress) + : HostedMcpServerTool(serverName, serverAddress); + } + + /// + /// An AI function wrapper that retries the invocation by recreating an MCP client when an occurs. + /// For example, this can happen if a session is revoked or a server error occurs. The retry evicts the cached MCP client. + /// + private sealed class McpRetriableAIFunction : DelegatingAIFunction + { + private readonly HostedMcpServerTool _hostedMcpTool; + private readonly Uri _serverAddressUri; + private readonly McpChatClient _chatClient; + + public McpRetriableAIFunction(AIFunction innerFunction, HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, McpChatClient chatClient) + : base(innerFunction) + { + _hostedMcpTool = hostedMcpTool; + _serverAddressUri = serverAddressUri; + _chatClient = chatClient; + } + + protected override async ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + try + { + return await base.InvokeCoreAsync(arguments, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) { } + + bool result = _chatClient.RemoveMcpClientFromCache(_hostedMcpTool.ServerAddress, out var removedTask); + Debug.Assert(result && removedTask!.Status == TaskStatus.RanToCompletion); + _ = removedTask!.Result.Client.DisposeAsync().AsTask(); + + var freshTool = await GetCurrentToolAsync().ConfigureAwait(false); + return await freshTool.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); + } + + private async Task GetCurrentToolAsync() + { + Debug.Assert(Uri.TryCreate(_hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) && + (parsedAddress.Scheme == Uri.UriSchemeHttp || parsedAddress.Scheme == Uri.UriSchemeHttps), + "Server address should have been validated before construction"); + + var (_, tools) = await _chatClient.GetClientAndToolsAsync(_hostedMcpTool, _serverAddressUri!).ConfigureAwait(false); + + return tools.FirstOrDefault(t => t.Name == Name) ?? + throw new McpProtocolException($"Tool '{Name}' no longer exists on the MCP server.", McpErrorCode.InvalidParams); + } + } +} diff --git a/src/ModelContextProtocol/McpClientTasksLruCache.cs b/src/ModelContextProtocol/McpClientTasksLruCache.cs new file mode 100644 index 000000000..5646aad37 --- /dev/null +++ b/src/ModelContextProtocol/McpClientTasksLruCache.cs @@ -0,0 +1,88 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Client; + +/// +/// A thread-safe Least Recently Used (LRU) cache for MCP client and tools. +/// +internal sealed class McpClientTasksLruCache : IDisposable +{ + private readonly Dictionary Node, Task<(McpClient Client, IList Tools)> Task)> _cache; + private readonly LinkedList _lruList; + private readonly object _lock = new(); + private readonly int _capacity; + + public McpClientTasksLruCache(int capacity) + { + Debug.Assert(capacity > 0); + _capacity = capacity; + _cache = new Dictionary, Task<(McpClient, IList)>)>(capacity); + _lruList = []; + } + + public Task<(McpClient Client, IList Tools)> GetOrAdd(string key, Func)>> valueFactory, TState state) + { + lock (_lock) + { + if (_cache.TryGetValue(key, out var existing)) + { + _lruList.Remove(existing.Node); + _lruList.AddLast(existing.Node); + return existing.Task; + } + + var value = valueFactory(key, state); + var newNode = _lruList.AddLast(key); + _cache[key] = (newNode, value); + + // Evict oldest if over capacity + if (_cache.Count > _capacity) + { + string oldestKey = _lruList.First!.Value; + _lruList.RemoveFirst(); + (_, Task<(McpClient Client, IList Tools)> task) = _cache[oldestKey]; + _cache.Remove(oldestKey); + + // Dispose evicted MCP client + if (task.Status == TaskStatus.RanToCompletion) + { + _ = task.Result.Client.DisposeAsync().AsTask(); + } + } + + return value; + } + } + + public bool TryRemove(string key, [MaybeNullWhen(false)] out Task<(McpClient Client, IList Tools)>? task) + { + lock (_lock) + { + if (_cache.TryGetValue(key, out var entry)) + { + _cache.Remove(key); + _lruList.Remove(entry.Node); + task = entry.Task; + return true; + } + + task = null; + return false; + } + } + + public void Dispose() + { + lock (_lock) + { + foreach ((_, Task<(McpClient Client, IList Tools)> task) in _cache.Values) + { + if (task.Status == TaskStatus.RanToCompletion) + { + _ = task.Result.Client.DisposeAsync().AsTask(); + } + } + } + } +} diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index b69108ab2..a4f9fc9c4 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -15,6 +15,7 @@ + @@ -23,6 +24,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs new file mode 100644 index 000000000..6d9a6bf3a --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs @@ -0,0 +1,802 @@ +using System.Runtime.CompilerServices; +using System.Text.Json; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using Moq; +#pragma warning disable MCP5002 +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class UseMcpClientTests : KestrelInMemoryTest +{ + public UseMcpClientTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + private async Task StartServerAsync(Action? configureApp = null) + { + IMcpServerBuilder builder = Builder.Services.AddMcpServer(options => + { + options.Capabilities = new ServerCapabilities + { + Tools = new(), + Resources = new(), + Prompts = new(), + }; + options.ServerInstructions = "This is a test server with only stub functionality"; + options.Handlers = new() + { + ListToolsHandler = async (request, cancellationToken) => + { + return new ListToolsResult + { + Tools = + [ + new Tool + { + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] + } + """), + }, + new Tool + { + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonElement.Parse(""" + { + "type": "object" + } + """), + }, + new Tool + { + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" + }, + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] + } + """), + } + ] + }; + }, + CallToolHandler = async (request, cancellationToken) => + { + if (request.Params is null) + { + throw new McpProtocolException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + } + if (request.Params.Name == "echo") + { + if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + { + throw new McpProtocolException("Missing required argument 'message'", McpErrorCode.InvalidParams); + } + return new CallToolResult + { + Content = [new TextContentBlock { Text = $"Echo: {message}" }] + }; + } + else if (request.Params.Name == "echoSessionId") + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] + }; + } + else if (request.Params.Name == "sampleLLM") + { + if (request.Params.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + { + throw new McpProtocolException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); + } + // Simple mock response for sampleLLM + return new CallToolResult + { + Content = [new TextContentBlock { Text = "LLM sampling result: Test response" }] + }; + } + else + { + throw new McpProtocolException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); + } + } + }; + }) + .WithHttpTransport(); + + var app = Builder.Build(); + configureApp?.Invoke(app); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + /// + /// Captures the arguments received by the leaf mock IChatClient. + /// + private sealed class LeafChatClientState + { + public ChatOptions? CapturedOptions { get; set; } + public List> CapturedMessages { get; set; } = []; + public int CallCount { get; set; } + public void Clear() + { + CapturedOptions = null; + CapturedMessages.Clear(); + CallCount = 0; + } + } + + private IChatClient CreateTestChatClient(out LeafChatClientState leafClientState, Action? configureTransportOptions = null) + { + var state = new LeafChatClientState(); + + var mockInnerClient = new Mock(); + mockInnerClient + .Setup(c => c.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns((IEnumerable messages, ChatOptions? options, CancellationToken ct) => + GetStreamingResponseAsync(messages, options, ct).ToChatResponseAsync(ct)); + + mockInnerClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(GetStreamingResponseAsync); + + leafClientState = state; + return mockInnerClient.Object.AsBuilder() + .UseMcpClient(HttpClient, LoggerFactory, configureTransportOptions) + // Placement is important, must be after UseMcpClient, otherwise, UseFunctionInvocation won't see the MCP tools. + .UseFunctionInvocation() + .Build(); + + async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + state.CapturedOptions = options; + state.CapturedMessages.Add(messages); + + // First call: request to invoke the echo tool + if (state.CallCount++ == 0 && options?.Tools is { Count: > 0 } tools) + { + Assert.Contains(tools, t => t.Name == "echo"); + yield return new ChatResponseUpdate(ChatRole.Assistant, + [ + new FunctionCallContent("call_123", "echo", new Dictionary { ["message"] = "test message" }) + ]); + } + else + { + // Subsequent calls: return final response + yield return new ChatResponseUpdate(ChatRole.Assistant, "Final response"); + } + } + } + + private static void AssertLeafClientMessagesWithInvocation(List> capturedMessages) + { + Assert.Equal(2, capturedMessages.Count); + var firstCall = capturedMessages[0]; + var msg = Assert.Single(firstCall); + Assert.Equal(ChatRole.User, msg.Role); + Assert.Equal("Test message", msg.Text); + + var secondCall = capturedMessages[1].ToList(); + Assert.Equal(3, secondCall.Count); + Assert.Equal(ChatRole.User, secondCall[0].Role); + Assert.Equal("Test message", secondCall[0].Text); + + Assert.Equal(ChatRole.Assistant, secondCall[1].Role); + var functionCall = Assert.IsType(Assert.Single(secondCall[1].Contents)); + Assert.Equal("call_123", functionCall.CallId); + Assert.Equal("echo", functionCall.Name); + + Assert.Equal(ChatRole.Tool, secondCall[2].Role); + var functionResult = Assert.IsType(Assert.Single(secondCall[2].Contents)); + Assert.Equal("call_123", functionResult.CallId); + Assert.Contains("Echo: test message", functionResult.Result?.ToString()); + } + + private static void AssertResponseWithInvocation(ChatResponse response) + { + Assert.NotNull(response); + Assert.Equal(3, response.Messages.Count); + + Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); + Assert.Single(response.Messages[0].Contents); + Assert.IsType(response.Messages[0].Contents[0]); + + Assert.Equal(ChatRole.Tool, response.Messages[1].Role); + Assert.Single(response.Messages[1].Contents); + Assert.IsType(response.Messages[1].Contents[0]); + + Assert.Equal(ChatRole.Assistant, response.Messages[2].Role); + Assert.Equal("Final response", response.Messages[2].Text); + } + + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task UseMcpClient_ShouldProduceTools(bool streaming, bool useUrl) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = useUrl ? + new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) : + new HostedMcpServerTool("serverName", HttpClient.BaseAddress!.ToString()); + mcpTool.ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_DoesNotConflictWithRegularTools(bool streaming) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var regularTool = AIFunctionFactory.Create(() => "regular tool result", "regularTool"); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!.ToString()) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions + { + Tools = [regularTool, mcpTool] + }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(4, toolNames.Count); + Assert.Contains("regularTool", toolNames); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + public static IEnumerable UseMcpClient_ApprovalMode_TestData() + { + string[] allToolNames = ["echo", "echoSessionId", "sampleLLM"]; + foreach (var streaming in new[] { false, true }) + { + yield return new object?[] { streaming, new HostedMcpServerToolNeverRequireApprovalMode(), (string[])[], allToolNames }; + yield return new object?[] { streaming, new HostedMcpServerToolAlwaysRequireApprovalMode(), allToolNames, (string[])[] }; + yield return new object?[] { streaming, null, allToolNames, (string[])[] }; + // Specific mode with empty lists - all tools should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], []), allToolNames, (string[])[] }; + // Specific mode with one tool always requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode(["echo"], []), allToolNames, (string[])[] }; + // Specific mode with one tool never requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], ["echo"]), (string[])["echoSessionId", "sampleLLM"], (string[])["echo"] }; + } + } + + [Theory] + [MemberData(nameof(UseMcpClient_ApprovalMode_TestData))] + public async Task UseMcpClient_ApprovalMode(bool streaming, HostedMcpServerToolApprovalMode? approvalMode, string[] expectedApprovalRequiredAIFunctions, string[] expectedNormalAIFunctions) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = approvalMode + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + Assert.Equal(3, leafClientState.CapturedOptions.Tools.Count); + + var toolsRequiringApproval = leafClientState.CapturedOptions.Tools + .Where(t => t is ApprovalRequiredAIFunction).Select(t => t.Name); + var toolsNotRequiringApproval = leafClientState.CapturedOptions.Tools + .Where(t => t is not ApprovalRequiredAIFunction).Select(t => t.Name); + + Assert.Equivalent(expectedApprovalRequiredAIFunctions, toolsRequiringApproval); + Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval); + } + + public static IEnumerable UseMcpClient_HandleFunctionApprovalRequest_TestData() + { + foreach (var streaming in new[] { false, true }) + { + // Approval modes that will cause function approval requests + yield return new object?[] { streaming, null }; + yield return new object?[] { streaming, HostedMcpServerToolApprovalMode.AlwaysRequire }; + yield return new object?[] { streaming, HostedMcpServerToolApprovalMode.RequireSpecific(["echo"], null) }; + } + } + + [Theory] + [MemberData(nameof(UseMcpClient_HandleFunctionApprovalRequest_TestData))] + public async Task UseMcpClient_HandleFunctionApprovalRequest(bool streaming, HostedMcpServerToolApprovalMode? approvalMode) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = approvalMode + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + List chatHistory = []; + chatHistory.Add(new ChatMessage(ChatRole.User, "Test message")); + var response = streaming ? + await sut.GetStreamingResponseAsync(chatHistory, options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync(chatHistory, options, TestContext.Current.CancellationToken); + + chatHistory.AddRange(response.Messages); + var approvalRequest = Assert.Single(response.Messages.SelectMany(m => m.Contents).OfType()); + chatHistory.Add(new ChatMessage(ChatRole.User, [approvalRequest.CreateResponse(true)])); + + response = streaming ? + await sut.GetStreamingResponseAsync(chatHistory, options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync(chatHistory, options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(true, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(false, (string[])["echo"], (string[])["echo"])] + [InlineData(true, (string[])["echo"], (string[])["echo"])] + [InlineData(false, (string[])[], (string[])[])] + [InlineData(true, (string[])[], (string[])[])] + public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, string[]? allowedTools, string[] expectedTools) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + AllowedTools = allowedTools, + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(expectedTools.Length, toolNames.Count); + Assert.Equivalent(expectedTools, toolNames); + + if (expectedTools.Contains("echo")) + { + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + else + { + var responseMsg = Assert.Single(response.Messages); + Assert.Equal(ChatRole.Assistant, responseMsg.Role); + Assert.Equal("Final response", responseMsg.Text); + + Assert.Single(leafClientState.CapturedMessages); + var firstCall = leafClientState.CapturedMessages[0]; + var leafClientMessage = Assert.Single(firstCall); + Assert.Equal(ChatRole.User, leafClientMessage.Role); + Assert.Equal("Test message", leafClientMessage.Text); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool streaming) + { + // Arrange + const string testToken = "test-bearer-token-12345"; + bool authReceivedForInitialize = false; + bool authReceivedForNotificationsInitialized = false; + bool authReceivedForToolsList = false; + bool authReceivedForToolsCall = false; + + await using var _ = await StartServerAsync( + configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST" && + context.Request.Headers.TryGetValue("Authorization", out var authHeader)) + { + Assert.Equal($"Bearer {testToken}", authHeader.ToString()); + + context.Request.EnableBuffering(); + JsonRpcRequest? rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + Assert.NotNull(rpcRequest); + + switch (rpcRequest.Method) + { + case "initialize": + authReceivedForInitialize = true; + break; + case "notifications/initialized": + authReceivedForNotificationsInitialized = true; + break; + case "tools/list": + authReceivedForToolsList = true; + break; + case "tools/call": + authReceivedForToolsCall = true; + break; + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + AuthorizationToken = testToken, + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions + { + Tools = [mcpTool] + }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.True(authReceivedForInitialize, "Authorization header was not captured in initial request"); + Assert.True(authReceivedForNotificationsInitialized, "Authorization header was not captured in notifications/initialized request"); + Assert.True(authReceivedForToolsList, "Authorization header was not captured in tools/list request"); + Assert.True(authReceivedForToolsCall, "Authorization header was not captured in tools/call request"); + + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) + { + // Arrange + int initializeCallCount = 0; + await using var _ = await StartServerAsync(configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "initialize") + { + initializeCallCount++; + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var firstCallToolCount = leafClientState.CapturedOptions.Tools.Count; + Assert.Equal(3, firstCallToolCount); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + Assert.Equal(1, initializeCallCount); + + // Arrange + leafClientState.Clear(); + + // Act + var secondResponse = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(secondResponse); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var secondCallToolCount = leafClientState.CapturedOptions.Tools.Count; + Assert.Equal(3, secondCallToolCount); + Assert.Equal(firstCallToolCount, secondCallToolCount); + toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + Assert.True(initializeCallCount == 1, "Initialize should not be called more than once because the MCP client is cached."); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_RetriesWhenSessionRevokedByServer(bool streaming) + { + // Arrange + string? firstSessionId = null; + string? secondSessionId = null; + + await using var app = await StartServerAsync( + configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "tools/call" && context.Request.Headers.TryGetValue("Mcp-Session-Id", out var sessionIdHeader)) + { + var sessionId = sessionIdHeader.ToString(); + + if (firstSessionId == null) + { + // First tool call - capture session and return 404 to revoke it + firstSessionId = sessionId; + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + else + { + // Second tool call - capture session and let it succeed + secondSessionId = sessionId; + } + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(firstSessionId); + Assert.NotNull(secondSessionId); + Assert.NotEqual(firstSessionId, secondSessionId); + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_RetriesOnServerError(bool streaming) + { + int toolCallCount = 0; + await using var app = await StartServerAsync(configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "tools/call" && ++toolCallCount == 1) + { + throw new Exception("Simulated server error."); + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.Equal(2, toolCallCount); + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_ConfigureTransportOptions_CallbackIsInvoked(bool streaming) + { + // Arrange + HostedMcpServerTool? capturedTool = null; + HttpClientTransportOptions? capturedTransportOptions = null; + await using var _ = await StartServerAsync(); + + using IChatClient sut = CreateTestChatClient(out var leafClientState, (tool, transportOptions) => + { + capturedTool = tool; + capturedTransportOptions = transportOptions; + }); + + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire, + AuthorizationToken = "test-auth-token-123" + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + + Assert.NotNull(capturedTool); + Assert.Equal("serverName", capturedTool.ServerName); + Assert.Equal(HttpClient.BaseAddress!.ToString(), capturedTool.ServerAddress); + Assert.Null(capturedTool.ServerDescription); + Assert.Null(capturedTool.AuthorizationToken); + Assert.Null(capturedTool.AllowedTools); + Assert.Null(capturedTool.ApprovalMode); + + Assert.NotNull(capturedTransportOptions); + Assert.Equal(HttpClient.BaseAddress, capturedTransportOptions.Endpoint); + Assert.Equal("serverName", capturedTransportOptions.Name); + Assert.Equal("Bearer test-auth-token-123", capturedTransportOptions.AdditionalHeaders!["Authorization"]); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_ThrowsInvalidOperationException_WhenServerAddressIsInvalid(bool streaming) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverNameConnector", "test-connector-123"); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => streaming ? + sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken)); + Assert.Contains("test-connector-123", exception.Message); + } +}