From ec5383c23aa2241180593fafbfe5e3befc555bed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 18 Mar 2025 13:23:35 +0100 Subject: [PATCH 001/222] Add CI GitHub Action and rename snapshot publish workflow --- .github/workflows/ci.yml | 22 +++++++++++++++++++ ...s-integration.yml => publish-snapshot.yml} | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci.yml rename .github/workflows/{continuous-integration.yml => publish-snapshot.yml} (98%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..7c73d9f38 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,22 @@ +name: CI + +on: + pull_request: {} + +jobs: + build: + name: Build branch + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build + run: mvn verify diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/publish-snapshot.yml similarity index 98% rename from .github/workflows/continuous-integration.yml rename to .github/workflows/publish-snapshot.yml index e0939f087..5d9b4aa39 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/publish-snapshot.yml @@ -1,4 +1,4 @@ -name: CI/CD build +name: Publish Snapshot on: push: From 6ef1b580347cc2dc343e63a60fe194031250a947 Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Tue, 18 Mar 2025 19:18:54 +0100 Subject: [PATCH 002/222] Update README.md Fix the build status badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index caa6bf0c0..ca87736cd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # MCP Java SDK -[![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/continuous-integration.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/continuous-integration.yml) +[![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. From 1a673b35672921c541e4feccf3d7ac4cd60c34ec Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 18 Mar 2025 18:57:22 +0100 Subject: [PATCH 003/222] refactor: improve MCP client timeout handling and reactive testing - Add configurable initialization timeout separate from request timeout - Rename ServletSse* test classes to HttpSse* for better naming consistency - Replace direct .block() calls with StepVerifier for better reactive testing - Change ping() method to return Mono instead of Mono - Improve error handling and reactive programming patterns throughout tests - Chain reactive operations for cleaner test flow Signed-off-by: Christian Tzolov --- .../client/WebFluxSseMcpAsyncClientTests.java | 7 - .../client/WebFluxSseMcpSyncClientTests.java | 7 - .../client/AbstractMcpAsyncClientTests.java | 255 ++++++++--------- .../client/AbstractMcpSyncClientTests.java | 17 +- .../client/McpAsyncClient.java | 19 +- .../client/McpClient.java | 33 ++- .../client/McpSyncClient.java | 7 +- .../client/AbstractMcpAsyncClientTests.java | 257 +++++++++--------- .../client/AbstractMcpSyncClientTests.java | 15 +- ...s.java => HttpSseMcpAsyncClientTests.java} | 7 +- ...ts.java => HttpSseMcpSyncClientTests.java} | 7 +- .../client/StdioMcpSyncClientTests.java | 2 +- 12 files changed, 339 insertions(+), 294 deletions(-) rename mcp/src/test/java/io/modelcontextprotocol/client/{ServletSseMcpAsyncClientTests.java => HttpSseMcpAsyncClientTests.java} (89%) rename mcp/src/test/java/io/modelcontextprotocol/client/{ServletSseMcpSyncClientTests.java => HttpSseMcpSyncClientTests.java} (89%) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 6cd74631e..021ce4654 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -48,9 +46,4 @@ public void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 6b980da41..20eeb1d59 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -48,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index cdcba4d1c..17cc99608 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -58,8 +58,12 @@ protected void onStart() { protected void onClose() { } - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } @BeforeEach @@ -69,7 +73,8 @@ void setUp() { assertThatCode(() -> { mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -78,8 +83,7 @@ void setUp() { @AfterEach void tearDown() { if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); } onClose(); } @@ -96,87 +100,93 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools"); + }).verify(); } @Test void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before pinging the server")) + .verify(); } @Test void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before calling tools")) + .verify(); } @Test void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull(); + assertThat(callToolResult.content()).isNotNull(); + assertThat(callToolResult.isError()).isNull(); + }) + .verifyComplete(); } @Test void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .expectError(Exception.class) + .verify(); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + StepVerifier.create(mcpAsyncClient.listResources(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resources")) + .verify(); } @Test void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test @@ -186,40 +196,44 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); + StepVerifier.create(mcpAsyncClient.listPrompts(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing prompts")) + .verify(); } @Test void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); + StepVerifier.create(mcpAsyncClient.getPrompt(request)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before getting prompts")) + .verify(); } @Test void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) .consumeNextWith(prompt -> { assertThat(prompt).isNotNull().satisfies(result -> { assertThat(result.messages()).isNotEmpty(); @@ -231,15 +245,16 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) + .expectErrorMatches(error -> error instanceof McpError && error.getMessage() + .equals("Client must be initialized before sending roots list changed notification")) + .verify(); } @Test void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); } @Test @@ -247,39 +262,39 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testAddRoot() { Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) + .verify(); } @Test void testRemoveRoot() { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) + .verify(); } @Test @@ -298,18 +313,20 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + StepVerifier.create(mcpAsyncClient.listResourceTemplates()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resource templates")) + .verify(); } @Test void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); } // @Test @@ -337,16 +354,13 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -356,15 +370,12 @@ void testInitializeWithSamplingCapability() { var capabilities = ClientCapabilities.builder().sampling().build(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -380,17 +391,17 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(samplingHandler) .build(); - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); + StepVerifier.create(client.initialize()).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + }).verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); } // --------------------------------------- @@ -399,19 +410,23 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before setting logging level")) + .verify(); } @Test void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } + StepVerifier.create(testAllLevels).verifyComplete(); } @Test @@ -420,20 +435,18 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index aeed06cbf..ee43a572e 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -52,8 +52,12 @@ public abstract class AbstractMcpSyncClientTests { abstract protected void onClose(); - protected Duration getTimeoutDuration() { - return Duration.ofSeconds(2); + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); } @BeforeEach @@ -63,7 +67,8 @@ void setUp() { assertThatCode(() -> { mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -215,7 +220,7 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); @@ -313,7 +318,7 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) @@ -351,7 +356,7 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index b301aa93a..4c5fd02ca 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -88,7 +88,6 @@ public class McpAsyncClient { /** * The max timeout to await for the client-server connection to be initialized. - * Usually x2 the request timeout. // TODO should we make it configurable? */ private final Duration initializationTimeout; @@ -151,18 +150,21 @@ public class McpAsyncClient { * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. + * @param initializationTimeout the max timeout to await for the client-server * @param features the MCP Client supported features. */ - McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, McpClientFeatures.Async features) { + McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Duration initializationTimeout, + McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); this.clientInfo = features.clientInfo(); this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); - this.initializationTimeout = requestTimeout.multipliedBy(2); + this.initializationTimeout = initializationTimeout; // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -367,12 +369,13 @@ private Mono withInitializationCheck(String actionName, /** * Sends a ping request to the server. - * @return A Mono that completes with the server's ping response + * @return A Mono that completes when the server responds to the ping */ - public Mono ping() { + public Mono ping() { return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - })); + }) + .then()); } // -------------------------- @@ -771,7 +774,9 @@ private NotificationHandler asyncLoggingNotificationHandler( * @see McpSchema.LoggingLevel */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { - Assert.notNull(loggingLevel, "Logging level must not be null"); + if (loggingLevel == null) { + return Mono.error(new McpError("Logging level must not be null")); + } return this.withInitializationCheck("setting logging level", initializedResult -> { String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index 7ab01b70c..fa2690dc3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -157,6 +157,8 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); @@ -193,6 +195,18 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public SyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -354,7 +368,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, asyncFeatures)); + return new McpSyncClient( + new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); } } @@ -381,6 +396,8 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private Duration initializationTimeout = Duration.ofSeconds(20); + private ClientCapabilities capabilities; private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); @@ -417,6 +434,18 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * @param initializationTimeout The duration to wait for the initializaiton + * lifecycle step to complete. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if initializationTimeout is null + */ + public AsyncSpec initializationTimeout(Duration initializationTimeout) { + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + this.initializationTimeout = initializationTimeout; + return this; + } + /** * Sets the client capabilities that will be advertised to the server during * connection initialization. Capabilities define what features the client @@ -574,7 +603,7 @@ public AsyncSpec loggingConsumers( * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { - return new McpAsyncClient(this.transport, this.requestTimeout, + return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler)); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index e5d964b7a..41f71d054 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -179,11 +179,10 @@ public void removeRoot(String rootUri) { } /** - * Send a synchronous ping request. - * @return + * Send a synchronous ping request to the server. */ - public Object ping() { - return this.delegate.ping().block(); + public void ping() { + this.delegate.ping().block(); } // -------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 661c629ea..969c3a866 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -59,7 +59,11 @@ protected void onStart() { protected void onClose() { } - protected Duration getTimeoutDuration() { + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } @@ -70,7 +74,8 @@ void setUp() { assertThatCode(() -> { mcpAsyncClient = McpClient.async(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -79,105 +84,110 @@ void setUp() { @AfterEach void tearDown() { if (mcpAsyncClient != null) { - assertThatCode(() -> mcpAsyncClient.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); } onClose(); } @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools"); + }).verify(); } @Test void testListTools() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }).verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before pinging the server")) + .verify(); } @Test void testPing() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before calling tools")) + .verify(); } @Test void testCallTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull(); + assertThat(callToolResult.content()).isNotNull(); + assertThat(callToolResult.isError()).isNull(); + }) + .verifyComplete(); } @Test void testCallToolWithInvalidTool() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); - assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .expectError(Exception.class) + .verify(); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + StepVerifier.create(mcpAsyncClient.listResources(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resources")) + .verify(); } @Test void testListResources() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test @@ -187,40 +197,44 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts"); + StepVerifier.create(mcpAsyncClient.listPrompts(null)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing prompts")) + .verify(); } @Test void testListPrompts() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts"); + StepVerifier.create(mcpAsyncClient.getPrompt(request)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before getting prompts")) + .verify(); } @Test void testGetPrompt() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of()))) + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) .consumeNextWith(prompt -> { assertThat(prompt).isNotNull().satisfies(result -> { assertThat(result.messages()).isNotEmpty(); @@ -232,15 +246,16 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) + .expectErrorMatches(error -> error instanceof McpError && error.getMessage() + .equals("Client must be initialized before sending roots list changed notification")) + .verify(); } @Test void testRootsListChanged() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); } @Test @@ -248,39 +263,39 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); - assertThatCode(() -> client.initialize().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - - assertThatCode(() -> client.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testAddRoot() { Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpAsyncClient.addRoot(newRoot).block()).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpAsyncClient.addRoot(null).block()).hasMessageContaining("Root must not be null"); + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) + .verify(); } @Test void testRemoveRoot() { Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpAsyncClient.addRoot(root).block(); - mcpAsyncClient.removeRoot(root.uri()).block(); - }).doesNotThrowAnyException(); + + StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpAsyncClient.removeRoot("nonexistent-uri").block()) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) + .verify(); } @Test @@ -299,18 +314,20 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + StepVerifier.create(mcpAsyncClient.listResourceTemplates()) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before listing resource templates")) + .verify(); } @Test void testListResourceTemplates() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); - - StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); } // @Test @@ -338,16 +355,13 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(); - client.closeGracefully().block(); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -357,15 +371,12 @@ void testInitializeWithSamplingCapability() { var capabilities = ClientCapabilities.builder().sampling().build(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test @@ -381,17 +392,17 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .capabilities(capabilities) .sampling(samplingHandler) .build(); - assertThatCode(() -> { - var result = client.initialize().block(Duration.ofSeconds(10)); + StepVerifier.create(client.initialize()).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.capabilities()).isNotNull(); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + }).verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); } // --------------------------------------- @@ -400,19 +411,23 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block()) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectErrorMatches(error -> error instanceof McpError + && error.getMessage().equals("Client must be initialized before setting logging level")) + .verify(); } @Test void testLoggingLevels() { - mcpAsyncClient.initialize().block(Duration.ofSeconds(10)); + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete(); - } + StepVerifier.create(testAllLevels).verifyComplete(); } @Test @@ -421,20 +436,18 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.async(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) .build(); - assertThatCode(() -> { - client.initialize().block(Duration.ofSeconds(10)); - client.closeGracefully().block(Duration.ofSeconds(10)); - }).doesNotThrowAnyException(); + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(null).block()) - .hasMessageContaining("Logging level must not be null"); + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 6f8cf198e..a866bfb32 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -53,7 +53,11 @@ public abstract class AbstractMcpSyncClientTests { abstract protected void onClose(); - protected Duration getTimeoutDuration() { + protected Duration getRequestTimeout() { + return Duration.ofSeconds(10); + } + + protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } @@ -64,7 +68,8 @@ void setUp() { assertThatCode(() -> { mcpSyncClient = McpClient.sync(mcpTransport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) .capabilities(ClientCapabilities.builder().roots(true).build()) .build(); }).doesNotThrowAnyException(); @@ -216,7 +221,7 @@ void testInitializeWithRootsListProviders() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .roots(new Root("file:///test/path", "test-root")) .build(); @@ -314,7 +319,7 @@ void testNotificationHandlers() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) @@ -352,7 +357,7 @@ void testLoggingConsumer() { var transport = createMcpTransport(); var client = McpClient.sync(transport) - .requestTimeout(getTimeoutDuration()) + .requestTimeout(getRequestTimeout()) .loggingConsumer(notification -> logReceived.set(true)) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 89% rename from mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 7cc673fa1..ac0fef24d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -18,7 +18,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { +class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { String host = "http://localhost:3004"; @@ -46,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java similarity index 89% rename from mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java rename to mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 2b8af41af..8772e6208 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/ServletSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -18,7 +18,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncClientTests extends AbstractMcpSyncClientTests { +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { String host = "http://localhost:3003"; @@ -46,9 +46,4 @@ protected void onClose() { container.stop(); } - @Override - protected Duration getTimeoutDuration() { - return Duration.ofMillis(300); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 7ae65253a..3517008c3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -39,7 +39,7 @@ void customErrorHandlerShouldReceiveErrors() { ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().tryEmitNext(errorMessage); + ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, null); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } From 914a14a92a53009125fbe85e61dba5937a8cd66f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 08:27:38 +0100 Subject: [PATCH 004/222] adjust test timeout values Signed-off-by: Christian Tzolov --- .../client/WebFluxSseMcpAsyncClientTests.java | 6 ++++++ .../client/WebFluxSseMcpSyncClientTests.java | 6 ++++++ .../client/AbstractMcpAsyncClientTests.java | 6 +++--- .../client/AbstractMcpSyncClientTests.java | 8 +++++--- .../client/AbstractMcpSyncClientTests.java | 6 ++++-- .../client/StdioMcpAsyncClientTests.java | 6 ++++++ .../client/StdioMcpSyncClientTests.java | 9 +++------ 7 files changed, 33 insertions(+), 14 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 021ce4654..0dccb27a7 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -46,4 +48,8 @@ public void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 20eeb1d59..f5cab7b73 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; @@ -46,4 +48,8 @@ protected void onClose() { container.stop(); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 17cc99608..2aa659cab 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -63,7 +63,7 @@ protected Duration getRequestTimeout() { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); + return Duration.ofSeconds(2); } @BeforeEach @@ -90,10 +90,10 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index ee43a572e..d1b752fcb 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -48,16 +48,18 @@ public abstract class AbstractMcpSyncClientTests { abstract protected ClientMcpTransport createMcpTransport(); - abstract protected void onStart(); + protected void onStart() { + } - abstract protected void onClose(); + protected void onClose() { + } protected Duration getRequestTimeout() { return Duration.ofSeconds(10); } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); + return Duration.ofSeconds(2); } @BeforeEach diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index a866bfb32..726632f3c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -49,9 +49,11 @@ public abstract class AbstractMcpSyncClientTests { abstract protected ClientMcpTransport createMcpTransport(); - abstract protected void onStart(); + protected void onStart() { + } - abstract protected void onClose(); + protected void onClose() { + } protected Duration getRequestTimeout() { return Duration.ofSeconds(10); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index ce74812b7..c285e2c65 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import java.time.Duration; + import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; import io.modelcontextprotocol.spec.ClientMcpTransport; @@ -26,4 +28,8 @@ protected ClientMcpTransport createMcpTransport() { return new StdioClientTransport(stdioParams); } + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 3517008c3..ec351623e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import java.time.Duration; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -44,12 +45,8 @@ void customErrorHandlerShouldReceiveErrors() { assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } - @Override - protected void onStart() { - } - - @Override - protected void onClose() { + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(6); } } From 264a0c90412558884dee9ed1db1c9bd9d4047afa Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 10:04:37 +0100 Subject: [PATCH 005/222] Address review comments Signed-off-by: Christian Tzolov --- .../client/AbstractMcpAsyncClientTests.java | 3 ++- .../io/modelcontextprotocol/client/McpAsyncClient.java | 7 +++---- .../java/io/modelcontextprotocol/client/McpSyncClient.java | 7 ++++--- .../client/AbstractMcpAsyncClientTests.java | 3 ++- .../client/StdioMcpSyncClientTests.java | 3 ++- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 2aa659cab..91dd223c5 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -129,7 +129,8 @@ void testPingWithoutInitialization() { @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { + }).verifyComplete(); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 4c5fd02ca..278e360d6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -369,13 +369,12 @@ private Mono withInitializationCheck(String actionName, /** * Sends a ping request to the server. - * @return A Mono that completes when the server responds to the ping + * @return A Mono that completes with the server's ping response */ - public Mono ping() { + public Mono ping() { return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - }) - .then()); + })); } // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 41f71d054..e5d964b7a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -179,10 +179,11 @@ public void removeRoot(String rootUri) { } /** - * Send a synchronous ping request to the server. + * Send a synchronous ping request. + * @return */ - public void ping() { - this.delegate.ping().block(); + public Object ping() { + return this.delegate.ping().block(); } // -------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 969c3a866..1bc40c52e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -130,7 +130,8 @@ void testPingWithoutInitialization() { @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { + }).verifyComplete(); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index ec351623e..6d759b4ba 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -12,6 +12,7 @@ import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Sinks; import static org.assertj.core.api.Assertions.assertThat; @@ -40,7 +41,7 @@ void customErrorHandlerShouldReceiveErrors() { ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, null); + ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); } From 92ec67a9650cc1c81f09de71c569345895d251fe Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 10:26:14 +0100 Subject: [PATCH 006/222] Increase the request timeout to 14 sec Signed-off-by: Christian Tzolov --- .../client/AbstractMcpAsyncClientTests.java | 2 +- .../modelcontextprotocol/client/AbstractMcpSyncClientTests.java | 2 +- .../client/AbstractMcpAsyncClientTests.java | 2 +- .../modelcontextprotocol/client/AbstractMcpSyncClientTests.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 91dd223c5..a8a59a63a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -59,7 +59,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index d1b752fcb..0f83e31e3 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -55,7 +55,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 1bc40c52e..39bc49953 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -60,7 +60,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 726632f3c..52a0138f6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -56,7 +56,7 @@ protected void onClose() { } protected Duration getRequestTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { From 37120f2ae585ee68e74b3c6734239797b84b4fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 19 Mar 2025 15:08:43 +0100 Subject: [PATCH 007/222] Improve client test reliability and execution time This change uses VirtualTimeScheduler and pretends enough time has passed to trigger a timeout on the initialization. Another problem with reliability of the tests was that the used testcontainer for the SSE server does not support multiple clients and the existence of both the global client for the entire suite and some customized local clients in some tests caused responses to be delivered to the other client at some racing situations. Now each test creates a dedicated client and performs cleanup locally. While these tests were improved, two other issues were found and fixed. The first one is that the closeGracefully of DefaultMcpSession was not lazy and would trigger connection disposal before the returned Mono was subscribed. The second one was dealing with closing the StdIo client before the process was started. In such a case there should not be an error but rather a warning and successful completion. --- .../client/AbstractMcpAsyncClientTests.java | 513 +++++++++++------- .../client/AbstractMcpSyncClientTests.java | 363 ++++++++----- .../transport/StdioClientTransport.java | 7 +- .../spec/DefaultMcpSession.java | 6 +- .../client/AbstractMcpAsyncClientTests.java | 462 +++++++++------- .../client/AbstractMcpSyncClientTests.java | 363 ++++++++----- .../client/StdioMcpSyncClientTests.java | 20 +- 7 files changed, 1016 insertions(+), 718 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index a8a59a63a..033139adc 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,7 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; @@ -44,10 +47,6 @@ */ public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -66,25 +65,47 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } @@ -93,258 +114,323 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listTools(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools")) + .verify(); + }); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.ping()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before pinging the " + "server")) + .verify(); + }); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.callTool(callToolRequest)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before calling tools")) + .verify(); + }); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResources(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resources")) + .verify(); + }); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listPrompts(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing prompts")) + .verify(); + }); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + withClient(createMcpTransport(), mcpAsyncClient -> { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.getPrompt(request)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before getting prompts")) + .verify(); + }); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.rootsListChangedNotification()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before sending roots list changed notification")) + .verify(); + }); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResourceTemplates()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resource templates")) + .verify(); + }); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -353,36 +439,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -391,18 +485,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -411,43 +501,52 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + withClient(createMcpTransport(), + mcpAsyncClient -> StepVerifier + .withVirtualTime(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before setting logging level")) + .verify()); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 0f83e31e3..032f8684a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -40,12 +47,8 @@ */ public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -62,254 +65,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -318,18 +389,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -338,40 +408,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 614c65125..d35db3f89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -353,14 +353,15 @@ public Mono closeGracefully() { // Give a short time for any pending messages to be processed return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromFuture(() -> { + })).then(Mono.defer(() -> { logger.debug("Sending TERM to process"); if (this.process != null) { this.process.destroy(); - return process.onExit(); + return Mono.fromFuture(process.onExit()); } else { - return CompletableFuture.failedFuture(new RuntimeException("Process not started")); + logger.warn("Process not started"); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index e2d354f4a..46aefafcb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -270,8 +270,10 @@ public Mono sendNotification(String method, Map params) { */ @Override public Mono closeGracefully() { - this.connection.dispose(); - return transport.closeGracefully(); + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); } /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 39bc49953..720388545 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,8 +6,12 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -45,10 +49,6 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -67,285 +67,326 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -354,36 +395,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -392,18 +441,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -412,43 +457,46 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 52a0138f6..1c042bf24 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -41,12 +48,8 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -63,254 +66,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -319,18 +390,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -339,40 +409,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 6d759b4ba..ebf10b9a3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -13,6 +15,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; @@ -35,15 +38,26 @@ protected ClientMcpTransport createMcpTransport() { } @Test - void customErrorHandlerShouldReceiveErrors() { + void customErrorHandlerShouldReceiveErrors() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); + ClientMcpTransport transport = createMcpTransport(); + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + ((StdioClientTransport) transport).setStdErrorHandler(error -> { + receivedError.set(error); + latch.countDown(); + }); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); + + StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); } protected Duration getInitializationTimeout() { From e996a5de1e72c11031d6a552af13f9d5dd95a427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 20 Mar 2025 09:03:08 +0100 Subject: [PATCH 008/222] Follow-up fix client tests reliability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../client/AbstractMcpAsyncClientTests.java | 15 +++------------ .../client/AbstractMcpSyncClientTests.java | 11 ++++++++--- .../client/AbstractMcpAsyncClientTests.java | 15 +++------------ .../client/AbstractMcpSyncClientTests.java | 11 ++++++++--- 4 files changed, 22 insertions(+), 30 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 033139adc..18ec06c63 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -446,18 +446,9 @@ void testNotificationHandlers() { resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), mcpAsyncClient -> { - - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer( - resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer( - prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 032f8684a..191de23b5 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -119,15 +119,20 @@ void verifyNotificationTimesOut(Consumer operation, String ac }, action); } - void verifyCallTimesOut(Function operation, String action) { + void verifyCallTimesOut(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { // This scheduler is not replaced by virtual time scheduler Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) - // offload the blocking call to the real scheduler + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler .subscribeOn(customScheduler)) .expectSubscription() + // This works without actually waiting but executes all the + // tasks pending execution on the VirtualTimeScheduler. + // It is possible to execute the blocking code from the operation + // because it is blocked on a dedicated Scheduler and the main + // flow is not blocked and uses the VirtualTimeScheduler. .thenAwait(getInitializationTimeout()) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be initialized before " + action)) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 720388545..06a231edd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -402,18 +402,9 @@ void testNotificationHandlers() { resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), mcpAsyncClient -> { - - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer( - resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer( - prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize()) + .expectNextMatches(Objects::nonNull) + .verifyComplete(); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 1c042bf24..f4d8dbdbc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -120,15 +120,20 @@ void verifyNotificationTimesOut(Consumer operation, String ac }, action); } - void verifyCallTimesOut(Function operation, String action) { + void verifyCallTimesOut(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { // This scheduler is not replaced by virtual time scheduler Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) - // offload the blocking call to the real scheduler + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler .subscribeOn(customScheduler)) .expectSubscription() + // This works without actually waiting but executes all the + // tasks pending execution on the VirtualTimeScheduler. + // It is possible to execute the blocking code from the operation + // because it is blocked on a dedicated Scheduler and the main + // flow is not blocked and uses the VirtualTimeScheduler. .thenAwait(getInitializationTimeout()) .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be initialized before " + action)) From 9c2b836e414ff11f3a57e925a85c433114df2a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 20 Mar 2025 09:36:56 +0100 Subject: [PATCH 009/222] Sync async client tests between mcp and mcp-test module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../client/AbstractMcpAsyncClientTests.java | 100 +++++------------- .../client/AbstractMcpAsyncClientTests.java | 1 - 2 files changed, 24 insertions(+), 77 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 18ec06c63..02aa23d8a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -109,6 +109,17 @@ void tearDown() { onClose(); } + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) @@ -121,14 +132,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.listTools(null)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools")) - .verify(); - }); + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); } @Test @@ -148,14 +152,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.ping()) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the " + "server")) - .verify(); - }); + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); } @Test @@ -169,16 +166,8 @@ void testPing() { @Test void testCallToolWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.withVirtualTime(() -> mcpAsyncClient.callTool(callToolRequest)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools")) - .verify(); - }); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -212,14 +201,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResources(null)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources")) - .verify(); - }); + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); } @Test @@ -250,14 +232,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.listPrompts(null)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing prompts")) - .verify(); - }); + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -281,16 +256,8 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.withVirtualTime(() -> mcpAsyncClient.getPrompt(request)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before getting prompts")) - .verify(); - }); + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -311,14 +278,8 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.rootsListChangedNotification()) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification")) - .verify(); - }); + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test @@ -392,14 +353,7 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResourceTemplates()) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates")) - .verify(); - }); + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); } @Test @@ -492,14 +446,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - withClient(createMcpTransport(), - mcpAsyncClient -> StepVerifier - .withVirtualTime(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level")) - .verify()); + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 06a231edd..f7a0a4924 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -11,7 +11,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; From 34a733509d88054b64ff84eaf43d0fb1f47bd1be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 20 Mar 2025 18:14:57 +0100 Subject: [PATCH 010/222] refactor: introduce session-based architecture for MCP server (#31) This commit introduces a major refactoring of the MCP Java SDK to implement a session-based architecture for server-side implementations. The changes improve the SDK's ability to handle multiple concurrent client connections and provide an API better aligned with the MCP specification. Key changes: - Introduce McpServerTransportProvider interface to manage client connections - Rename ClientMcpTransport to McpClientTransport and ServerMcpTransport to McpServerTransport - Add exchange objects (McpAsyncServerExchange, McpSyncServerExchange) for client interaction - Update handler signatures to include exchange parameter: (args) -> result to (exchange, args) -> result - Rename Registration classes to Specification classes - Update method names (e.g., rootsChangeConsumers to rootsChangeHandlers) - Deprecate old interfaces and classes for removal in 0.9.0 - Add migration guide (migration-0.8.0.md) Resolves #9, #15 Co-authored-by: Christian Tzolov Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseClientTransport.java | 4 +- .../transport/WebFluxSseServerTransport.java | 17 +- .../WebFluxSseServerTransportProvider.java | 351 ++++ .../WebFluxSseIntegrationTests.java | 208 ++- .../client/WebFluxSseMcpAsyncClientTests.java | 4 +- .../client/WebFluxSseMcpSyncClientTests.java | 4 +- ...bFluxSseMcpAsyncServerDeprecatedTests.java | 55 + .../server/WebFluxSseMcpAsyncServerTests.java | 15 +- ...ebFluxSseMcpSyncServerDeprecatecTests.java | 55 + .../server/WebFluxSseMcpSyncServerTests.java | 16 +- .../legacy/WebFluxSseIntegrationTests.java | 459 +++++ .../transport/WebMvcSseServerTransport.java | 5 +- .../WebMvcSseServerTransportProvider.java | 399 ++++ ...seAsyncServerTransportDeprecatedTests.java | 118 ++ .../WebMvcSseAsyncServerTransportTests.java | 24 +- .../WebMvcSseIntegrationDeprecatedTests.java | 508 +++++ .../server/WebMvcSseIntegrationTests.java | 195 +- ...SseSyncServerTransportDeprecatedTests.java | 118 ++ .../WebMvcSseSyncServerTransportTests.java | 23 +- .../MockMcpTransport.java | 8 +- .../client/AbstractMcpAsyncClientTests.java | 12 +- .../client/AbstractMcpSyncClientTests.java | 12 +- ...AbstractMcpAsyncServerDeprecatedTests.java | 465 +++++ .../server/AbstractMcpAsyncServerTests.java | 127 +- .../AbstractMcpSyncServerDeprecatedTests.java | 431 +++++ .../server/AbstractMcpSyncServerTests.java | 131 +- .../client/McpAsyncClient.java | 12 +- .../client/McpClient.java | 41 + .../client/McpSyncClient.java | 5 +- .../HttpClientSseClientTransport.java | 28 +- .../transport/StdioClientTransport.java | 5 +- .../server/McpAsyncServer.java | 1641 +++++++++++++---- .../server/McpAsyncServerExchange.java | 104 ++ .../server/McpServer.java | 1003 +++++++++- .../server/McpServerFeatures.java | 360 +++- .../server/McpSyncServer.java | 48 + .../server/McpSyncServerExchange.java | 78 + .../HttpServletSseServerTransport.java | 5 +- ...HttpServletSseServerTransportProvider.java | 432 +++++ .../transport/StdioServerTransport.java | 3 + .../StdioServerTransportProvider.java | 306 +++ .../spec/ClientMcpTransport.java | 2 + .../spec/DefaultMcpSession.java | 3 + .../spec/McpClientSession.java | 288 +++ .../spec/McpClientTransport.java | 21 + .../spec/McpServerSession.java | 354 ++++ .../spec/McpServerTransport.java | 11 + .../spec/McpServerTransportProvider.java | 66 + .../modelcontextprotocol/spec/McpSession.java | 15 +- .../spec/McpTransport.java | 9 +- .../spec/ServerMcpTransport.java | 2 + .../MockMcpTransport.java | 6 +- .../client/AbstractMcpAsyncClientTests.java | 12 +- .../client/AbstractMcpSyncClientTests.java | 12 +- .../client/HttpSseMcpAsyncClientTests.java | 6 +- .../client/HttpSseMcpSyncClientTests.java | 6 +- .../client/StdioMcpAsyncClientTests.java | 4 +- .../client/StdioMcpSyncClientTests.java | 6 +- ...AbstractMcpAsyncServerDeprecatedTests.java | 466 +++++ .../server/AbstractMcpAsyncServerTests.java | 127 +- .../AbstractMcpSyncServerDeprecatedTests.java | 433 +++++ .../server/AbstractMcpSyncServerTests.java | 131 +- .../server/BaseMcpAsyncServerTests.java | 5 + ...rvletSseMcpAsyncServerDeprecatedTests.java | 26 + .../server/ServletSseMcpAsyncServerTests.java | 10 +- ...ervletSseMcpSyncServerDeprecatedTests.java | 26 + .../server/ServletSseMcpSyncServerTests.java | 10 +- .../StdioMcpAsyncServerDeprecatedTests.java | 25 + .../server/StdioMcpAsyncServerTests.java | 7 +- .../StdioMcpSyncServerDeprecatedTests.java | 25 + .../server/StdioMcpSyncServerTests.java | 10 +- .../server/transport/BlockingInputStream.java | 69 - ...rverTransportProviderIntegrationTests.java | 493 +++++ .../StdioServerTransportProviderTests.java | 227 +++ ...nTests.java => McpClientSessionTests.java} | 20 +- migration-0.8.0.md | 328 ++++ 76 files changed, 9969 insertions(+), 1127 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java rename mcp/src/test/java/io/modelcontextprotocol/spec/{DefaultMcpSessionTests.java => McpClientSessionTests.java} (90%) create mode 100644 migration-0.8.0.md diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 8ea65fd78..b0dfa89c0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,7 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -58,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java index bed7293ee..fb0b581e0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -60,7 +60,10 @@ * @author Alexandros Pappas * @see ServerMcpTransport * @see ServerSentEvent + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebFluxSseServerTransportProvider}. */ +@Deprecated public class WebFluxSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); @@ -182,16 +185,16 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { try {// @formatter:off String jsonText = objectMapper.writeValueAsString(message); ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); + .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) + .map(session -> session.id) + .toList(); if (failedSessions.isEmpty()) { logger.debug("Successfully broadcast message to all sessions"); @@ -407,4 +410,4 @@ void close() { } -} +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java new file mode 100644 index 000000000..cf3eeae03 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -0,0 +1,351 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using + * Server-Sent Events (SSE). This implementation provides a bidirectional communication + * channel between MCP clients and servers using HTTP POST for client-to-server messages + * and SSE for server-to-client messages. + * + *

+ * Key features: + *

    + *
  • Implements the {@link McpServerTransportProvider} interface that allows managing + * {@link McpServerSession} instances and enabling their communication with the + * {@link McpServerTransport} abstraction.
  • + *
  • Uses WebFlux for non-blocking request handling and SSE support
  • + *
  • Maintains client sessions for reliable message delivery
  • + *
  • Supports graceful shutdown with session cleanup
  • + *
  • Thread-safe message broadcasting to multiple clients
  • + *
+ * + *

+ * The transport sets up two main endpoints: + *

    + *
  • SSE endpoint (/sse) - For establishing SSE connections with clients
  • + *
  • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
  • + *
+ * + *

+ * This implementation is thread-safe and can handle multiple concurrent client + * connections. It uses {@link ConcurrentHashMap} for session management and Project + * Reactor's non-blocking APIs for message processing and delivery. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @author Dariusz Jędrzejczyk + * @see McpServerTransport + * @see ServerSentEvent + */ +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

+ * The method: + *

    + *
  • Serializes the message to JSON
  • + *
  • Creates a server-sent event with the message data
  • + *
  • Attempts to send the event to all active sessions
  • + *
  • Tracks and reports any delivery failures
  • + *
+ * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromStream(sessions.values().stream()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // doing that. + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

+ * The shutdown process: + *

    + *
  • Marks the transport as closing to prevent new connections
  • + *
  • Closes each active session
  • + *
  • Removes closed sessions from the sessions map
  • + *
  • Times out after 5 seconds if shutdown takes too long
  • + *
+ * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpServerSession::closeGracefully) + .then(); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines two endpoints: + *

    + *
  • GET {sseEndpoint} - For establishing SSE connections
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleSseConnection(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); + + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + + logger.debug("Created new SSE connection for session: {}", sessionId); + sessions.put(sessionId, session); + + // Send initial endpoint event + logger.debug("Sending initial endpoint event to session: {}", sessionId); + sink.next(ServerSentEvent.builder() + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId) + .build()); + sink.onCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }); + }), ServerSentEvent.class); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

+ * The handler: + *

    + *
  • Deserializes the incoming JSON-RPC message
  • + *
  • Passes it through the message handler chain
  • + *
  • Returns appropriate HTTP responses based on processing results
  • + *
  • Handles various error conditions with appropriate error responses
  • + *
+ * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handleMessage(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + if (request.queryParam("sessionId").isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); + } + + McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }); + } + + private class WebFluxMcpSessionTransport implements McpServerTransport { + + private final FluxSink> sink; + + public WebFluxMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 4cd24c621..57bcd191b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -16,7 +16,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -30,9 +30,9 @@ import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -44,8 +44,8 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; public class WebFluxSseIntegrationTests { @@ -55,16 +55,16 @@ public class WebFluxSseIntegrationTests { private DisposableServer httpServer; - private WebFluxSseServerTransport mcpServerTransport; + private WebFluxSseServerTransportProvider mcpServerTransportProvider; ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); @BeforeEach public void before() { - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); @@ -84,57 +84,43 @@ public void after() { // --------------------------------------- // Sampling Tests // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageWithoutSamplingCapabilities(String clientType) { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - var clientBuilder = clientBulders.get(clientType); - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be configured with sampling capabilities"); - }); + } } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageSuccess(String clientType) throws InterruptedException { + // Client var clientBuilder = clientBulders.get(clientType); - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -143,29 +129,54 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) .capabilities(ClientCapabilities.builder().sampling().build()) .sampling(samplingHandler) .build(); - InitializeResult initResult = client.initialize(); + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); } // --------------------------------------- @@ -179,8 +190,8 @@ void testRootsSuccess(String clientType) { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -192,8 +203,6 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).isNull(); - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - mcpClient.rootsListChangedNotification(); await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { @@ -222,23 +231,33 @@ void testRootsSuccess(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithoutCapability(String clientType) { + var clientBuilder = clientBulders.get(clientType); - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No - // roots - // capability - .build(); + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } mcpClient.close(); mcpServer.close(); @@ -246,12 +265,12 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithEmptyRootsList(String clientType) { + void testRootsNotifciationWithEmptyRootsList(String clientType) { var clientBuilder = clientBulders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -273,7 +292,7 @@ void testRootsWithEmptyRootsList(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleConsumers(String clientType) { + void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBulders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -281,9 +300,9 @@ void testRootsWithMultipleConsumers(String clientType) { AtomicReference> rootsRef1 = new AtomicReference<>(); AtomicReference> rootsRef2 = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -313,8 +332,8 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -356,8 +375,8 @@ void testToolCallSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -368,7 +387,7 @@ void testToolCallSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -396,8 +415,8 @@ void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -408,7 +427,7 @@ void testToolListChangeHandlingSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -446,8 +465,8 @@ void testToolListChangeHandlingSuccess(String clientType) { }); // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -459,4 +478,21 @@ void testToolListChangeHandlingSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0dccb27a7..2dd587d4f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index f5cab7b73..72b390ddd 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 000000000..b460284ee --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + private static final int PORT = 8181; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + @Override + protected ServerMcpTransport createMcpTransport() { + var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 1ed0d99b5..5fa787ab6 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -30,14 +30,13 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; @Override - protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + protected McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - return transport; + return transportProvider; } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java new file mode 100644 index 000000000..be2bf6c7f --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpSyncServerDeprecatecTests extends AbstractMcpSyncServerDeprecatedTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport transport; + + @Override + protected ServerMcpTransport createMcpTransport() { + transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + return transport; + } + + @Override + protected void onStart() { + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 4db00dd47..d3672e3f3 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -29,17 +29,17 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private DisposableServer httpServer; - private WebFluxSseServerTransport transport; + private WebFluxSseServerTransportProvider transportProvider; @Override - protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - return transport; + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return transportProvider; } @Override protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java new file mode 100644 index 000000000..981e114c9 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java @@ -0,0 +1,459 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.legacy; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +public class WebFluxSseIntegrationTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport mcpServerTransport; + + ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); + clientBulders.put("webflux", + McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutInitialization() { + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var clientBuilder = clientBulders.get(clientType); + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + var clientBuilder = clientBulders.get(clientType); + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpServer.listRoots().roots()).containsAll(roots); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { + }).build(); + + // Create client without roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No + // roots + // capability + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Attempt to list roots should fail + assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) + .hasMessage("Roots not supported"); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithEmptyRootsList(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleConsumers(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) + .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java index 00928ec7f..23193d106 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java @@ -33,6 +33,9 @@ * a bridge between synchronous WebMVC operations and reactive programming patterns to * maintain compatibility with the reactive transport interface. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebMvcSseServerTransportProvider}. + * *

* Key features: *