diff --git a/README.md b/README.md index 436104c63..7bda15006 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # MCP Java SDK +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/license/MIT) [![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) +[![Maven Central](https://img.shields.io/maven-central/v/io.modelcontextprotocol.sdk/mcp.svg?label=Maven%20Central)](https://central.sonatype.com/artifact/io.modelcontextprotocol.sdk/mcp) +[![Java Version](https://img.shields.io/badge/Java-17%2B-orange)](https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.html) + A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. @@ -43,6 +47,7 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - Christian Tzolov - Dariusz Jędrzejczyk +- Daniel Garnier-Moiroux ## Links @@ -50,6 +55,133 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) - [CI/CD](https://github.com/modelcontextprotocol/java-sdk/actions) +## Architecture and Design Decisions + +### Introduction + +Building a general-purpose MCP Java SDK requires making technology decisions in areas where the JDK provides limited or no support. The Java ecosystem is powerful but fragmented: multiple valid approaches exist, each with strong communities. +Our goal is not to prescribe "the one true way," but to provide a reference implementation of the MCP specification that is: + +* **Pragmatic** – makes developers productive quickly +* **Interoperable** – aligns with widely used libraries and practices +* **Pluggable** – allows alternatives where projects prefer different stacks +* **Grounded in team familiarity** – we chose technologies the team can be productive with today, while remaining open to community contributions that broaden the SDK + +### Key Choices and Considerations + +The SDK had to make decisions in the following areas: + +1. **JSON serialization** – mapping between JSON and Java types + +2. **Programming model** – supporting asynchronous processing, cancellation, and streaming while staying simple for blocking use cases + +3. **Observability** – logging and enabling integration with metrics/tracing + +4. **Remote clients and servers** – supporting both consuming MCP servers (client transport) and exposing MCP endpoints (server transport with authorization) + +The following sections explain what we chose, why it made sense, and how the choices align with the SDK's goals. + +### 1. JSON Serialization + +* **SDK Choice**: Jackson for JSON serialization and deserialization, behind an SDK abstraction (`mcp-json`) + +* **Why**: Jackson is widely adopted across the Java ecosystem, provides strong performance and a mature annotation model, and is familiar to the SDK team and many potential contributors. + +* **How we expose it**: Public APIs use a zero-dependency abstraction (`mcp-json`). Jackson is shipped as the default implementation (`mcp-jackson2`), but alternatives can be plugged in. + +* **How it fits the SDK**: This offers a pragmatic default while keeping flexibility for projects that prefer different JSON libraries. + +### 2. Programming Model + +* **SDK Choice**: Reactive Streams for public APIs, with Project Reactor as the internal implementation and a synchronous facade for blocking use cases + +* **Why**: MCP builds on JSON-RPC's asynchronous nature and defines a bidirectional protocol on top of it, enabling asynchronous and streaming interactions. MCP explicitly supports: + + * Multiple in-flight requests and responses + * Notifications that do not expect a reply + * STDIO transports for inter-process communication using pipes + * Streaming transports such as Server-Sent Events and Streamable HTTP + + These requirements call for a programming model more powerful than single-result futures like `CompletableFuture`. + + * **Reactive Streams: the Community Standard** + + Reactive Streams is a small Java specification that standardizes asynchronous stream processing with backpressure. It defines four minimal interfaces (Publisher, Subscriber, Subscription, and Processor). These interfaces are widely recognized as the standard contract for async, non-blocking pipelines in Java. + + * **Reactive Streams Implementation** + + The SDK uses Project Reactor as its implementation of the Reactive Streams specification. Reactor is mature, widely adopted, provides rich operators, and integrates well with observability through context propagation. Team familiarity also allowed us to deliver a solid foundation quickly. + We plan to convert the public API to only expose Reactive Streams interfaces. By defining the public API in terms of Reactive Streams interfaces and using Reactor internally, the SDK stays standards-based while benefiting from a practical, production-ready implementation. + + * **Synchronous Facade in the SDK** + + Not all MCP use cases require streaming pipelines. Many scenarios are as simple as "send a request and block until I get the result." + To support this, the SDK provides a synchronous facade layered on top of the reactive core. Developers can stay in a blocking model when it's enough, while still having access to asynchronous streaming when needed. + +* **How it fits the SDK**: This design balances scalability, approachability, and future evolution such as Virtual Threads and Structured Concurrency in upcoming JDKs. + +### 3. Observability + +* **SDK Choice**: SLF4J for logging; Reactor Context for observability propagation + +* **Why**: SLF4J is the de facto logging facade in Java, with broad compatibility. Reactor Context enables propagation of observability data such as correlation IDs and tracing state across async boundaries. This ensures interoperability with modern observability frameworks. + +* **How we expose it**: Public APIs log through SLF4J only, with no backend included. Observability metadata flows through Reactor pipelines. The SDK itself does not ship metrics or tracing implementations. + +* **How it fits the SDK**: This provides reliable logging by default and seamless integration with Micrometer, OpenTelemetry, or similar systems for metrics and tracing. + +### 4. Remote MCP Clients and Servers + +MCP supports both clients (applications consuming MCP servers) and servers (applications exposing MCP endpoints). The SDK provides support for both sides. + +#### Client Transport in the SDK + +* **SDK Choice**: JDK HttpClient (Java 11+) as the default client, with optional Spring WebClient support + +* **Why**: The JDK HttpClient is built-in, portable, and supports streaming responses. This keeps the default lightweight with no extra dependencies. Spring WebClient support is available for Spring-based projects. + +* **How we expose it**: MCP Client APIs are transport-agnostic. The core module ships with JDK HttpClient transport. A Spring module provides WebClient integration. + +* **How it fits the SDK**: This ensures all applications can talk to MCP servers out of the box, while allowing richer integration in Spring and other environments. + +#### Server Transport in the SDK + +* **SDK Choice**: Jakarta Servlet implementation in core, with optional Spring WebFlux and Spring WebMVC providers + +* **Why**: Servlet is the most widely deployed Java server API. WebFlux and WebMVC cover a significant part of the Spring community. Together these provide reach across blocking and non-blocking models. + +* **How we expose it**: Server APIs are transport-agnostic. Core includes Servlet support. Spring modules extend support for WebFlux and WebMVC. + +* **How it fits the SDK**: This allows developers to expose MCP servers in the most common Java environments today, while enabling other transport implementations such as Netty, Vert.x, or Helidon. + +#### Authorization in the SDK + +* **SDK Choice**: Pluggable authorization hooks for MCP servers; no built-in implementation + +* **Why**: MCP servers must restrict access to authenticated and authorized clients. Authorization needs differ across environments such as Spring Security, MicroProfile JWT, or custom solutions. Providing hooks avoids lock-in and leverages proven libraries. + +* **How we expose it**: Authorization is integrated into the server transport layer. The SDK does not include its own authorization system. + +* **How it fits the SDK**: This keeps server-side security ecosystem-neutral, while ensuring applications can plug in their preferred authorization strategy. + +### Project Structure of the SDK + +The SDK is organized into modules to separate concerns and allow adopters to bring in only what they need: +* `mcp-bom` – Dependency versions +* `mcp-core` – Reference implementation (STDIO, JDK HttpClient, Servlet) +* `mcp-json` – JSON abstraction +* `mcp-jackson2` – Jackson implementation of JSON binding +* `mcp` – Convenience bundle (core + Jackson) +* `mcp-test` – Shared testing utilities +* `mcp-spring` – Spring integrations (WebClient, WebFlux, WebMVC) + +For example, a minimal adopter may depend only on `mcp` (core + Jackson), while a Spring-based application can use `mcp-spring` for deeper framework integration. + +### Future Directions + +The SDK is designed to evolve with the Java ecosystem. Areas we are actively watching include: +Concurrency in the JDK – Virtual Threads and Structured Concurrency may simplify the synchronous API story + ## License This project is licensed under the [MIT License](LICENSE). diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 83d8bc510..447c9e0bd 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-bom @@ -27,12 +27,33 @@ + + io.modelcontextprotocol.sdk + mcp-core + ${project.version} + + + io.modelcontextprotocol.sdk mcp ${project.version} + + + io.modelcontextprotocol.sdk + mcp-json + ${project.version} + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + ${project.version} + + io.modelcontextprotocol.sdk diff --git a/mcp-core/pom.xml b/mcp-core/pom.xml new file mode 100644 index 000000000..9e23ffd79 --- /dev/null +++ b/mcp-core/pom.xml @@ -0,0 +1,235 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-core + jar + Java MCP SDK Core + Core classes of the Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + + org.slf4j + slf4j-api + ${slf4j-api.version} + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + + io.projectreactor + reactor-core + + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + provided + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + test + + + + org.springframework + spring-webmvc + ${springframework.version} + test + + + + + io.projectreactor.netty + reactor-netty-http + test + + + + + org.springframework + spring-context + ${springframework.version} + test + + + + org.springframework + spring-test + ${springframework.version} + test + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + + + io.projectreactor + reactor-test + test + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + test + + + + org.awaitility + awaitility + ${awaitility.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + + + + + com.google.code.gson + gson + 2.10.1 + test + + + + + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index 2cc1c5dba..07d86f40e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -11,14 +11,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.util.context.ContextView; @@ -99,21 +98,30 @@ class LifecycleInitializer { */ private final Duration initializationTimeout; + /** + * Post-initialization hook to perform additional operations after every successful + * initialization. + */ + private final Function> postInitializationHook; + public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, List protocolVersions, Duration initializationTimeout, - Function sessionSupplier) { + Function sessionSupplier, + Function> postInitializationHook) { Assert.notNull(sessionSupplier, "Session supplier must not be null"); Assert.notNull(clientCapabilities, "Client capabilities must not be null"); Assert.notNull(clientInfo, "Client info must not be null"); Assert.notEmpty(protocolVersions, "Protocol versions must not be empty"); Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + Assert.notNull(postInitializationHook, "Post-initialization hook must not be null"); this.sessionSupplier = sessionSupplier; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions)); this.initializationTimeout = initializationTimeout; + this.postInitializationHook = postInitializationHook; } /** @@ -148,10 +156,6 @@ interface Initialization { } - /** - * Default implementation of the {@link Initialization} interface that manages the MCP - * client initialization process. - */ private static class DefaultInitialization implements Initialization { /** @@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) { this.mcpClientSession.set(mcpClientSession); } - /** - * Returns a Mono that completes when the MCP client initialization is complete. - * This allows subscribers to wait for the initialization to finish before - * proceeding with further operations. - * @return A Mono that emits the result of the MCP initialization process - */ private Mono await() { return this.initSink.asMono(); } - /** - * Completes the initialization process with the given result. It caches the - * result and emits it to all subscribers waiting for the initialization to - * complete. - * @param initializeResult The result of the MCP initialization process - */ private void complete(McpSchema.InitializeResult initializeResult) { - // first ensure the result is cached - this.result.set(initializeResult); // inform all the subscribers waiting for the initialization this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); } + private void cacheResult(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + } + private void error(Throwable t) { this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); } @@ -263,7 +258,7 @@ public void handleException(Throwable t) { } // Providing an empty operation since we are only interested in triggering // the implicit initialization step. - withIntitialization("re-initializing", result -> Mono.empty()).subscribe(); + this.withInitialization("re-initializing", result -> Mono.empty()).subscribe(); } } @@ -275,7 +270,7 @@ public void handleException(Throwable t) { * @param operation The operation to execute when the client is initialized * @return A Mono that completes with the result of the operation */ - public Mono withIntitialization(String actionName, Function> operation) { + public Mono withInitialization(String actionName, Function> operation) { return Mono.deferContextual(ctx -> { DefaultInitialization newInit = new DefaultInitialization(); DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit); @@ -283,19 +278,24 @@ public Mono withIntitialization(String actionName, Function initializationJob = needsToInitialize ? doInitialize(newInit, ctx) - : previous.await(); + Mono initializationJob = needsToInitialize + ? this.doInitialize(newInit, this.postInitializationHook, ctx) : previous.await(); return initializationJob.map(initializeResult -> this.initializationRef.get()) .timeout(this.initializationTimeout) .onErrorResume(ex -> { + this.initializationRef.compareAndSet(newInit, null); return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); }) - .flatMap(operation); + .flatMap(res -> operation.apply(res) + .contextWrite(c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + res.initializeResult().protocolVersion()))); }); } - private Mono doInitialize(DefaultInitialization initialization, ContextView ctx) { + private Mono doInitialize(DefaultInitialization initialization, + Function> postInitOperation, ContextView ctx) { + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); McpClientSession mcpClientSession = initialization.mcpSession(); @@ -321,7 +321,12 @@ private Mono doInitialize(DefaultInitialization init } return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .contextWrite( + c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, initializeResult.protocolVersion())) .thenReturn(initializeResult); + }).flatMap(initializeResult -> { + initialization.cacheResult(initializeResult); + return postInitOperation.apply(initialization).thenReturn(initializeResult); }).doOnNext(initialization::complete).onErrorResume(ex -> { initialization.error(ex); return Mono.error(ex); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java similarity index 82% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index eb6d42f68..e6a09cd08 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -15,14 +15,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; - +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; @@ -36,10 +35,10 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; -import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -76,6 +75,7 @@ * @author Dariusz Jędrzejczyk * @author Christian Tzolov * @author Jihoon Kim + * @author Anurag Pant * @see McpClient * @see McpSchema * @see McpClientSession @@ -85,27 +85,29 @@ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeRef VOID_TYPE_REFERENCE = new TypeRef<>() { }; - public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef PAGINATED_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + public static final TypeRef INITIALIZE_RESULT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + public static final TypeRef LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() { + public static final TypeRef PROGRESS_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; + public static final String NEGOTIATED_PROTOCOL_VERSION = "io.modelcontextprotocol.client.negotiated-protocol-version"; + /** * Client capabilities. */ @@ -153,16 +155,33 @@ public class McpAsyncClient { */ private final LifecycleInitializer initializer; + /** + * JSON schema validator to use for validating tool responses against output schemas. + */ + private final JsonSchemaValidator jsonSchemaValidator; + + /** + * Cached tool output schemas. + */ + private final ConcurrentHashMap> toolsOutputSchemaCache; + + /** + * Whether to enable automatic schema caching during callTool operations. + */ + private final boolean enableCallToolSchemaCaching; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. * @param initializationTimeout the max timeout to await for the client-server - * @param features the MCP Client supported features. + * @param jsonSchemaValidator the JSON schema validator to use for validating tool + * @param features the MCP Client supported features. responses against output + * schemas. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -172,6 +191,9 @@ public class McpAsyncClient { this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); + this.jsonSchemaValidator = jsonSchemaValidator; + this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); + this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching(); // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -274,9 +296,30 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, asyncProgressNotificationHandler(progressConsumersFinal)); + Function> postInitializationHook = init -> { + + if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { + return Mono.empty(); + } + + return this.listToolsInternal(init, McpSchema.FIRST_PAGE).doOnNext(listToolsResult -> { + listToolsResult.tools() + .forEach(tool -> logger.debug("Tool {} schema: {}", tool.name(), tool.outputSchema())); + if (enableCallToolSchemaCaching && listToolsResult.tools() != null) { + // Cache tools output schema + listToolsResult.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }).then(); + }; + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx))); + notificationHandlers, con -> con.contextWrite(ctx)), + postInitializationHook); + this.transport.setExceptionHandler(this.initializer::handleException); } @@ -361,6 +404,7 @@ public Mono closeGracefully() { // -------------------------- // Initialization // -------------------------- + /** * The initialization phase should be the first interaction between client and server. * The client will ensure it happens in case it has not been explicitly called and in @@ -388,7 +432,7 @@ public Mono closeGracefully() { *

*/ public Mono initialize() { - return this.initializer.withIntitialization("by explicit API call", init -> Mono.just(init.initializeResult())); + return this.initializer.withInitialization("by explicit API call", init -> Mono.just(init.initializeResult())); } // -------------------------- @@ -400,13 +444,14 @@ public Mono initialize() { * @return A Mono that completes with the server's ping response */ public Mono ping() { - return this.initializer.withIntitialization("pinging the server", + return this.initializer.withInitialization("pinging the server", init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- // Roots // -------------------------- + /** * Adds a new root to the client's root list. * @param root The root to add. @@ -481,7 +526,7 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.initializer.withIntitialization("sending roots list changed notification", + return this.initializer.withInitialization("sending roots list changed notification", init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } @@ -512,7 +557,7 @@ private RequestHandler samplingCreateMessageHandler() { // -------------------------- private RequestHandler elicitationCreateHandler() { return params -> { - ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + ElicitRequest request = transport.unmarshalFrom(params, new TypeRef<>() { }); return this.elicitationHandler.apply(request); @@ -522,10 +567,10 @@ private RequestHandler elicitationCreateHandler() { // -------------------------- // Tools // -------------------------- - private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CALL_TOOL_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_TOOLS_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -540,27 +585,57 @@ private RequestHandler elicitationCreateHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.initializer.withIntitialization("calling tools", init -> { + return this.initializer.withInitialization("calling tool", init -> { if (init.initializeResult().capabilities().tools() == null) { return Mono.error(new IllegalStateException("Server does not provide tools capability")); } + return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF) + .flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result))); }); } + private McpSchema.CallToolResult validateToolResult(String toolName, McpSchema.CallToolResult result) { + + if (!this.enableCallToolSchemaCaching || result == null || result.isError() == Boolean.TRUE) { + // if tool schema caching is disabled or tool call resulted in an error - skip + // validation and return the result as it is + return result; + } + + Map optOutputSchema = this.toolsOutputSchemaCache.get(toolName); + + if (optOutputSchema == null) { + logger.warn( + "Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}", + result.structuredContent()); + return result; + } + + // Validate the tool output against the cached output schema + var validation = this.jsonSchemaValidator.validate(optOutputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + throw new IllegalArgumentException("Tool call result validation failed: " + validation.errorMessage()); + } + + return result; + } + /** * Retrieves the list of all tools provided by the server. * @return A Mono that emits the list of all tools result */ public Mono listTools() { - return this.listTools(McpSchema.FIRST_PAGE) - .expand(result -> (result.nextCursor() != null) ? this.listTools(result.nextCursor()) : Mono.empty()) - .reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { - allToolsResult.tools().addAll(result.tools()); - return allToolsResult; - }) - .map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); + return this.listTools(McpSchema.FIRST_PAGE).expand(result -> { + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTools(next) : Mono.empty(); + }).reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { + allToolsResult.tools().addAll(result.tools()); + return allToolsResult; + }).map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); } /** @@ -569,14 +644,26 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.initializer.withIntitialization("listing tools", init -> { - if (init.initializeResult().capabilities().tools() == null) { - return Mono.error(new IllegalStateException("Server does not provide tools capability")); - } - return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); - }); + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor)); + } + + private Mono listToolsInternal(Initialization init, String cursor) { + + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF) + .doOnNext(result -> { + if (this.enableCallToolSchemaCaching && result.tools() != null) { + // Cache tools output schema + result.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }); } private NotificationHandler asyncToolsChangeNotificationHandler( @@ -596,13 +683,13 @@ private NotificationHandler asyncToolsChangeNotificationHandler( // Resources // -------------------------- - private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCES_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef READ_RESOURCE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -633,7 +720,7 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.initializer.withIntitialization("listing resources", init -> { + return this.initializer.withInitialization("listing resources", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -665,7 +752,7 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.initializer.withIntitialization("reading resources", init -> { + return this.initializer.withInitialization("reading resources", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -703,7 +790,7 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.initializer.withIntitialization("listing resource templates", init -> { + return this.initializer.withInitialization("listing resource templates", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -723,7 +810,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.initializer.withIntitialization("subscribing to resources", init -> init.mcpSession() + return this.initializer.withInitialization("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -737,7 +824,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.initializer.withIntitialization("unsubscribing from resources", init -> init.mcpSession() + return this.initializer.withInitialization("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -756,7 +843,7 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( List, Mono>> resourcesUpdateConsumers) { return params -> { McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification = transport.unmarshalFrom(params, - new TypeReference<>() { + new TypeRef<>() { }); return readResource(new McpSchema.ReadResourceRequest(resourcesUpdatedNotification.uri())) @@ -773,10 +860,10 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( // -------------------------- // Prompts // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_PROMPTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef GET_PROMPT_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -803,7 +890,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.initializer.withIntitialization("listing prompts", init -> init.mcpSession() + return this.initializer.withInitialization("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -817,7 +904,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.initializer.withIntitialization("getting prompts", init -> init.mcpSession() + return this.initializer.withInitialization("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -835,14 +922,6 @@ private NotificationHandler asyncPromptsChangeNotificationHandler( // -------------------------- // Logging // -------------------------- - /** - * Create a notification handler for logging notifications from the server. This - * handler automatically distributes logging messages to all registered consumers. - * @param loggingConsumers List of consumers that will be notified when a logging - * message is received. Each consumer receives the logging message notification. - * @return A NotificationHandler that processes log notifications by distributing the - * message to all registered consumers - */ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { @@ -868,7 +947,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new IllegalArgumentException("Logging level must not be null")); } - return this.initializer.withIntitialization("setting logging level", init -> { + return this.initializer.withInitialization("setting logging level", init -> { if (init.initializeResult().capabilities().logging() == null) { return Mono.error(new IllegalStateException("Server's Logging capabilities are not enabled!")); } @@ -877,15 +956,6 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { }); } - /** - * Create a notification handler for progress notifications from the server. This - * handler automatically distributes progress notifications to all registered - * consumers. - * @param progressConsumers List of consumers that will be notified when a progress - * message is received. Each consumer receives the progress notification. - * @return A NotificationHandler that processes progress notifications by distributing - * the message to all registered consumers - */ private NotificationHandler asyncProgressNotificationHandler( List>> progressConsumers) { @@ -911,7 +981,7 @@ void setProtocolVersions(List protocolVersions) { // -------------------------- // Completions // -------------------------- - private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -925,7 +995,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.initializer.withIntitialization("complete completions", init -> init.mcpSession() + return this.initializer.withInitialization("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java similarity index 86% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index c8af28ac1..c9989f832 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -4,17 +4,10 @@ package io.modelcontextprotocol.client; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -22,9 +15,19 @@ import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -72,6 +75,7 @@ * .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> System.out.println("Resources updated: " + resources))) * .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> System.out.println("Prompts updated: " + prompts))) * .loggingConsumer(message -> Mono.fromRunnable(() -> System.out.println("Log message: " + message))) + * .resourcesUpdateConsumer(resourceContents -> Mono.fromRunnable(() -> System.out.println("Resources contents updated: " + resourceContents))) * .build(); * } * @@ -97,6 +101,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Anurag Pant * @see McpAsyncClient * @see McpSyncClient * @see McpTransport @@ -163,7 +168,7 @@ class SyncSpec { private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -183,6 +188,12 @@ class SyncSpec { private Function elicitationHandler; + private Supplier contextProvider = () -> McpTransportContext.EMPTY; + + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -336,6 +347,22 @@ public SyncSpec resourcesChangeConsumer(Consumer> resou return this; } + /** + * Adds a consumer to be notified when a specific resource is updated. This allows + * the client to react to changes in individual resources, such as updates to + * their content or metadata. + * @param resourcesUpdateConsumer A consumer function that processes the updated + * resource and returns a Mono indicating the completion of the processing. Must + * not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException If the resourcesUpdateConsumer is null. + */ + public SyncSpec resourcesUpdateConsumer(Consumer> resourcesUpdateConsumer) { + Assert.notNull(resourcesUpdateConsumer, "Resources update consumer must not be null"); + this.resourcesUpdateConsumers.add(resourcesUpdateConsumer); + return this; + } + /** * Adds a consumer to be notified when the available prompts change. This allows * the client to react to changes in the server's prompt templates, such as new @@ -409,6 +436,48 @@ public SyncSpec progressConsumers(List> return this; } + /** + * Add a provider of {@link McpTransportContext}, providing a context before + * calling any client operation. This allows to extract thread-locals and hand + * them over to the underlying transport. + *

+ * There is no direct equivalent in {@link AsyncSpec}. To achieve the same result, + * append {@code contextWrite(McpTransportContext.KEY, context)} to any + * {@link McpAsyncClient} call. + * @param contextProvider A supplier to create a context + * @return This builder for method chaining + */ + public SyncSpec transportContextProvider(Supplier contextProvider) { + this.contextProvider = contextProvider; + return this; + } + + /** + * Add a {@link JsonSchemaValidator} to validate the JSON structure of the + * structured output. + * @param jsonSchemaValidator A validator to validate the JSON structure of the + * structured output. Must not be null. + * @return This builder for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -418,12 +487,13 @@ public McpSyncClient build() { McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, - this.elicitationHandler); + this.elicitationHandler, this.enableCallToolSchemaCaching); McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault(), + asyncFeatures), this.contextProvider); } } @@ -454,7 +524,7 @@ class AsyncSpec { private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -474,6 +544,10 @@ class AsyncSpec { private Function> elicitationHandler; + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -720,17 +794,45 @@ public AsyncSpec progressConsumers( return this; } + /** + * Sets the JSON schema validator to use for validating tool responses against + * output schemas. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, - this.samplingHandler, this.elicitationHandler)); + this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java similarity index 94% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 3b6550765..127d53337 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -62,6 +62,7 @@ class McpClientFeatures { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, @@ -71,7 +72,8 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> loggingConsumers, List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -84,6 +86,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -94,7 +97,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> loggingConsumers, List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -113,6 +117,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } /** @@ -129,7 +134,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c Function> elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler); + elicitationHandler, false); } /** @@ -187,7 +192,8 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, progressConsumers, samplingHandler, elicitationHandler); + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, + syncSpec.enableCallToolSchemaCaching); } } @@ -205,6 +211,7 @@ public static Async fromSync(Sync syncSpec) { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -214,7 +221,8 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List> loggingConsumers, List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -229,6 +237,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -238,7 +247,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List> loggingConsumers, List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -257,6 +267,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } /** @@ -272,7 +283,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl Function elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler); + elicitationHandler, false); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java similarity index 82% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 33784adcd..7fdaa8941 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,16 +5,19 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; /** * A synchronous client implementation for the Model Context Protocol (MCP) that wraps an @@ -63,14 +66,20 @@ public class McpSyncClient implements AutoCloseable { private final McpAsyncClient delegate; + private final Supplier contextProvider; + /** * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. + * @param contextProvider the supplier of context before calling any non-blocking + * operation on underlying delegate */ - McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate, Supplier contextProvider) { Assert.notNull(delegate, "The delegate can not be null"); + Assert.notNull(contextProvider, "The contextProvider can not be null"); this.delegate = delegate; + this.contextProvider = contextProvider; } /** @@ -177,14 +186,14 @@ public boolean closeGracefully() { public McpSchema.InitializeResult initialize() { // TODO: block takes no argument here as we assume the async client is // configured with a requestTimeout at all times - return this.delegate.initialize().block(); + return withProvidedContext(this.delegate.initialize()).block(); } /** * Send a roots/list_changed notification. */ public void rootsListChangedNotification() { - this.delegate.rootsListChangedNotification().block(); + withProvidedContext(this.delegate.rootsListChangedNotification()).block(); } /** @@ -206,7 +215,7 @@ public void removeRoot(String rootUri) { * @return */ public Object ping() { - return this.delegate.ping().block(); + return withProvidedContext(this.delegate.ping()).block(); } // -------------------------- @@ -224,7 +233,8 @@ public Object ping() { * Boolean indicating if the execution failed (true) or succeeded (false/absent) */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { - return this.delegate.callTool(callToolRequest).block(); + return withProvidedContext(this.delegate.callTool(callToolRequest)).block(); + } /** @@ -234,7 +244,7 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque * pagination if more tools are available */ public McpSchema.ListToolsResult listTools() { - return this.delegate.listTools().block(); + return withProvidedContext(this.delegate.listTools()).block(); } /** @@ -245,7 +255,8 @@ public McpSchema.ListToolsResult listTools() { * pagination if more tools are available */ public McpSchema.ListToolsResult listTools(String cursor) { - return this.delegate.listTools(cursor).block(); + return withProvidedContext(this.delegate.listTools(cursor)).block(); + } // -------------------------- @@ -257,7 +268,8 @@ public McpSchema.ListToolsResult listTools(String cursor) { * @return The list of all resources result */ public McpSchema.ListResourcesResult listResources() { - return this.delegate.listResources().block(); + return withProvidedContext(this.delegate.listResources()).block(); + } /** @@ -266,7 +278,8 @@ public McpSchema.ListResourcesResult listResources() { * @return The list of resources result */ public McpSchema.ListResourcesResult listResources(String cursor) { - return this.delegate.listResources(cursor).block(); + return withProvidedContext(this.delegate.listResources(cursor)).block(); + } /** @@ -275,7 +288,8 @@ public McpSchema.ListResourcesResult listResources(String cursor) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { - return this.delegate.readResource(resource).block(); + return withProvidedContext(this.delegate.readResource(resource)).block(); + } /** @@ -284,7 +298,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.delegate.readResource(readResourceRequest).block(); + return withProvidedContext(this.delegate.readResource(readResourceRequest)).block(); + } /** @@ -292,7 +307,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r * @return The list of all resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { - return this.delegate.listResourceTemplates().block(); + return withProvidedContext(this.delegate.listResourceTemplates()).block(); + } /** @@ -304,7 +320,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { * @return The list of resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor) { - return this.delegate.listResourceTemplates(cursor).block(); + return withProvidedContext(this.delegate.listResourceTemplates(cursor)).block(); + } /** @@ -317,7 +334,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor * subscribe to. */ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - this.delegate.subscribeResource(subscribeRequest).block(); + withProvidedContext(this.delegate.subscribeResource(subscribeRequest)).block(); + } /** @@ -326,7 +344,8 @@ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { * to unsubscribe from. */ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - this.delegate.unsubscribeResource(unsubscribeRequest).block(); + withProvidedContext(this.delegate.unsubscribeResource(unsubscribeRequest)).block(); + } // -------------------------- @@ -338,7 +357,7 @@ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) * @return The list of all prompts result. */ public ListPromptsResult listPrompts() { - return this.delegate.listPrompts().block(); + return withProvidedContext(this.delegate.listPrompts()).block(); } /** @@ -347,11 +366,12 @@ public ListPromptsResult listPrompts() { * @return The list of prompts result. */ public ListPromptsResult listPrompts(String cursor) { - return this.delegate.listPrompts(cursor).block(); + return withProvidedContext(this.delegate.listPrompts(cursor)).block(); + } public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { - return this.delegate.getPrompt(getPromptRequest).block(); + return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block(); } /** @@ -359,7 +379,8 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { * @param loggingLevel the min logging level */ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { - this.delegate.setLoggingLevel(loggingLevel).block(); + withProvidedContext(this.delegate.setLoggingLevel(loggingLevel)).block(); + } /** @@ -369,7 +390,18 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { * @return the completion result containing suggested values. */ public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.delegate.completeCompletion(completeRequest).block(); + return withProvidedContext(this.delegate.completeCompletion(completeRequest)).block(); + + } + + /** + * For a given action, on assembly, capture the "context" via the + * {@link #contextProvider} and store it in the Reactor context. + * @param action the action to perform + * @return the result of the action + */ + private Mono withProvidedContext(Mono action) { + return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java similarity index 71% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 0f3511afb..ae093316f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -18,16 +18,18 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; @@ -94,8 +96,8 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** HTTP request builder for building requests to send messages to the server */ private final HttpRequest.Builder requestBuilder; - /** JSON object mapper for message serialization/deserialization */ - protected ObjectMapper objectMapper; + /** JSON mapper for message serialization/deserialization */ + protected McpJsonMapper jsonMapper; /** Flag indicating if the transport is in closing state */ private volatile boolean isClosing = false; @@ -112,67 +114,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Customizer to modify requests before they are executed. */ - private final AsyncHttpRequestCustomizer httpRequestCustomizer; - - /** - * Creates a new transport instance with default HTTP client and object mapper. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(String baseUri) { - this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { - this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, - ObjectMapper objectMapper) { - this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param clientBuilder the HTTP client builder to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, - String baseUri, String sseEndpoint, ObjectMapper objectMapper) { - this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); - } + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; /** * Creates a new transport instance with custom HTTP client builder, object mapper, @@ -181,30 +123,14 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @param requestBuilder the HTTP request builder to use * @param baseUri the base URI of the MCP server * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - */ - @Deprecated(forRemoval = true) - HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper) { - this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param httpClient the HTTP client to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization + * @param jsonMapper the object mapper for JSON serialization/deserialization * @param httpRequestCustomizer customizer for the requestBuilder before executing * requests * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); @@ -212,7 +138,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.httpRequestCustomizer = httpRequestCustomizer; @@ -241,16 +167,15 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private HttpClient.Builder clientBuilder = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_1_1) - .connectTimeout(Duration.ofSeconds(10)); + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() - .header("Content-Type", "application/json"); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); - private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); /** * Creates a new builder instance. @@ -339,13 +264,13 @@ public Builder customizeRequest(final Consumer requestCusto } /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper + * Sets the JSON mapper implementation to use for serialization/deserialization. + * @param jsonMapper the JSON mapper * @return this builder */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -354,16 +279,17 @@ public Builder objectMapper(ObjectMapper objectMapper) { * executing them. *

* This overrides the customizer from - * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. *

- * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking - * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} * instead. * @param syncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { - this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); return this; } @@ -372,24 +298,36 @@ public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCu * executing them. *

* This overrides the customizer from - * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. *

* Do NOT use a blocking implementation in a non-blocking context. * @param asyncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { this.httpRequestCustomizer = asyncHttpRequestCustomizer; return this; } + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper, httpRequestCustomizer); + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer); } } @@ -398,14 +336,15 @@ public HttpClientSseClientTransport build() { public Mono connect(Function, Mono> handler) { var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = requestBuilder.copy() .uri(uri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) .GET(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }).flatMap(requestBuilder -> Mono.create(sink -> { Disposable connection = Flux.create(sseSink -> this.httpClient .sendAsync(requestBuilder.build(), @@ -435,7 +374,7 @@ public Mono connect(Function, Mono> h } } else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, responseEvent.sseEvent().data()); sink.success(); return Flux.just(message); @@ -516,7 +455,7 @@ public Mono sendMessage(JSONRPCMessage message) { private Mono serializeMessage(final JSONRPCMessage message) { return Mono.defer(() -> { try { - return Mono.just(objectMapper.writeValueAsString(message)); + return Mono.just(jsonMapper.writeValueAsString(message)); } catch (IOException e) { return Mono.error(new McpTransportException("Failed to serialize message", e)); @@ -526,12 +465,14 @@ private Mono serializeMessage(final JSONRPCMessage message) { private Mono> sendHttpPost(final String endpoint, final String body) { final URI requestUri = Utils.resolveUri(baseUri, endpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = this.requestBuilder.copy() .uri(requestUri) + .header(HttpHeaders.CONTENT_TYPE, "application/json") .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) .POST(HttpRequest.BodyPublishers.ofString(body)); - return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body, transportContext)); }).flatMap(customizedBuilder -> { var request = customizedBuilder.build(); return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); @@ -565,8 +506,8 @@ public Mono closeGracefully() { * @return the unmarshalled object */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java similarity index 76% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 93c28422a..e41f45ebb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -11,6 +11,8 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.time.Duration; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletionException; @@ -18,14 +20,14 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.HttpHeaders; @@ -38,6 +40,9 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; @@ -74,8 +79,6 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); - private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; - private static final String DEFAULT_ENDPOINT = "/mcp"; /** @@ -103,7 +106,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { public static int BAD_REQUEST = 400; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final URI baseUri; @@ -113,18 +116,23 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private final AsyncHttpRequestCustomizer httpRequestCustomizer; + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference> activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); private final AtomicReference> exceptionHandler = new AtomicReference<>(); - private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) { - this.objectMapper = objectMapper; + boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.baseUri = URI.create(baseUri); @@ -133,11 +141,16 @@ private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); } @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return supportedProtocolVersions; } public static Builder builder(String baseUri) { @@ -159,23 +172,34 @@ public Mono connect(Function, Mono createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() : createDelete(sessionId); return new DefaultMcpTransportSession(onClose); } + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + private Publisher createDelete(String sessionId) { var uri = Utils.resolveUri(this.baseUri, this.endpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = this.requestBuilder.copy() .uri(uri) .header("Cache-Control", "no-cache") .header(HttpHeaders.MCP_SESSION_ID, sessionId) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .DELETE(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null, transportContext)); }).flatMap(requestBuilder -> { var request = requestBuilder.build(); return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); @@ -205,9 +229,9 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); if (currentSession != null) { - return currentSession.closeGracefully(); + return Mono.from(currentSession.closeGracefully()); } return Mono.empty(); }); @@ -228,7 +252,7 @@ private Mono reconnect(McpTransportStream stream) { final McpTransportSession transportSession = this.activeSession.get(); var uri = Utils.resolveUri(this.baseUri, this.endpoint); - Disposable connection = Mono.defer(() -> { + Disposable connection = Mono.deferContextual(connectionCtx -> { HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); if (transportSession != null && transportSession.sessionId().isPresent()) { @@ -241,11 +265,14 @@ private Mono reconnect(McpTransportStream stream) { } var builder = requestBuilder.uri(uri) - .header("Accept", TEXT_EVENT_STREAM) + .header(HttpHeaders.ACCEPT, TEXT_EVENT_STREAM) .header("Cache-Control", "no-cache") - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + connectionCtx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .GET(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }) .flatMapMany( requestBuilder -> Flux.create( @@ -273,7 +300,7 @@ private Mono reconnect(McpTransportStream stream) { // won't since the next version considers // removing it. McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( - this.objectMapper, responseEvent.sseEvent().data()); + this.jsonMapper, responseEvent.sseEvent().data()); Tuple2, Iterable> idWithMessages = Tuples .of(Optional.ofNullable(responseEvent.sseEvent().id()), @@ -365,7 +392,7 @@ private BodyHandler toSendMessageBodySubscriber(FluxSink si BodyHandler responseBodyHandler = responseInfo -> { - String contentType = responseInfo.headers().firstValue("Content-Type").orElse("").toLowerCase(); + String contentType = responseInfo.headers().firstValue(HttpHeaders.CONTENT_TYPE).orElse("").toLowerCase(); if (contentType.contains(TEXT_EVENT_STREAM)) { // For SSE streams, use line subscriber that returns Void @@ -388,7 +415,7 @@ else if (contentType.contains(APPLICATION_JSON)) { public String toString(McpSchema.JSONRPCMessage message) { try { - return this.objectMapper.writeValueAsString(message); + return this.jsonMapper.writeValueAsString(message); } catch (IOException e) { throw new RuntimeException("Failed to serialize JSON-RPC message", e); @@ -405,7 +432,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { var uri = Utils.resolveUri(this.baseUri, this.endpoint); String jsonBody = this.toString(sentMessage); - Disposable connection = Mono.defer(() -> { + Disposable connection = Mono.deferContextual(ctx -> { HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); if (transportSession != null && transportSession.sessionId().isPresent()) { @@ -414,12 +441,16 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { } var builder = requestBuilder.uri(uri) - .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) - .header("Content-Type", APPLICATION_JSON) - .header("Cache-Control", "no-cache") - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.ACCEPT, APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON) + .header(HttpHeaders.CACHE_CONTROL, "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); - return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody, transportContext)); }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { // Create the async request with proper body subscriber selection @@ -451,15 +482,19 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { String contentType = responseEvent.responseInfo() .headers() - .firstValue("Content-Type") + .firstValue(HttpHeaders.CONTENT_TYPE) .orElse("") .toLowerCase(); - if (contentType.isBlank()) { - logger.debug("No content type returned for POST in session {}", sessionRepresentation); + String contentLength = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_LENGTH) + .orElse(null); + + if (contentType.isBlank() || "0".equals(contentLength)) { + logger.debug("No body returned for POST in session {}", sessionRepresentation); // No content type means no response body, so we can just - // return - // an empty stream + // return an empty stream deliveredSink.success(); return Flux.empty(); } @@ -472,7 +507,7 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { // since the // next version considers removing it. McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + .deserializeJsonRpcMessage(this.jsonMapper, sseEvent.data()); Tuple2, Iterable> idWithMessages = Tuples .of(Optional.ofNullable(sseEvent.id()), List.of(message)); @@ -495,13 +530,14 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { else if (contentType.contains(APPLICATION_JSON)) { deliveredSink.success(); String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); - if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) { - logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(data) ? data : "[empty]"); return Mono.empty(); } try { - return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + return Mono.just(McpSchema.deserializeJsonRpcMessage(jsonMapper, data)); } catch (IOException e) { return Mono.error(new McpTransportException( @@ -575,8 +611,8 @@ private static String sessionIdOrPlaceholder(McpTransportSession transportSes } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } /** @@ -586,11 +622,9 @@ public static class Builder { private final String baseUri; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; - private HttpClient.Builder clientBuilder = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_1_1) - .connectTimeout(Duration.ofSeconds(10)); + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); private String endpoint = DEFAULT_ENDPOINT; @@ -600,7 +634,12 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); - private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); + + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); /** * Creates a new builder with the specified base URI. @@ -656,13 +695,13 @@ public Builder customizeRequest(final Consumer requestCusto } /** - * Configure the {@link ObjectMapper} to use. - * @param objectMapper instance to use + * Configure a custom {@link McpJsonMapper} implementation to use. + * @param jsonMapper instance to use * @return the builder instance */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -709,16 +748,17 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { * executing them. *

* This overrides the customizer from - * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. *

- * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking - * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} * instead. * @param syncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { - this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); return this; } @@ -727,27 +767,62 @@ public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCu * executing them. *

* This overrides the customizer from - * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. *

* Do NOT use a blocking implementation in a non-blocking context. * @param asyncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { this.httpRequestCustomizer = asyncHttpRequestCustomizer; return this; } + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

+ * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. * @return a new instance of {@link HttpClientStreamableHttpTransport} */ public HttpClientStreamableHttpTransport build() { - ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - - return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, - endpoint, resumableStreams, openConnectionOnStartup, httpRequestCustomizer); + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, + httpRequestCustomizer, supportedProtocolVersions); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 296d1a17d..29dc23c35 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -141,7 +141,6 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(String line) { - if (line.isEmpty()) { // Empty line means end of event if (this.eventBuilder.length() > 0) { @@ -158,23 +157,27 @@ protected void hookOnNext(String line) { if (matcher.find()) { this.eventBuilder.append(matcher.group(1).trim()).append("\n"); } + upstream().request(1); } else if (line.startsWith("id:")) { var matcher = EVENT_ID_PATTERN.matcher(line); if (matcher.find()) { this.currentEventId.set(matcher.group(1).trim()); } + upstream().request(1); } else if (line.startsWith("event:")) { var matcher = EVENT_TYPE_PATTERN.matcher(line); if (matcher.find()) { this.currentEventType.set(matcher.group(1).trim()); } + upstream().request(1); } else if (line.startsWith(":")) { // Ignore comment lines starting with ":" // This is a no-op, just to skip comments logger.debug("Ignoring comment line: {}", line); + upstream().request(1); } else { // If the response is not successful, emit an error @@ -220,6 +223,8 @@ static class AggregateSubscriber extends BaseSubscriber { */ private ResponseInfo responseInfo; + volatile boolean hasRequestedDemand = false; + /** * Creates a new JsonLineSubscriber that will emit parsed JSON-RPC messages. * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects @@ -233,7 +238,13 @@ public AggregateSubscriber(ResponseInfo responseInfo, FluxSink si @Override protected void hookOnSubscribe(Subscription subscription) { - sink.onRequest(subscription::request); + + sink.onRequest(n -> { + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; + }); // Register disposal callback to cancel subscription when Flux is disposed sink.onDispose(subscription::cancel); @@ -246,8 +257,11 @@ protected void hookOnNext(String line) { @Override protected void hookOnComplete() { - String data = this.eventBuilder.toString(); - this.sink.next(new AggregateResponseEvent(responseInfo, data)); + + if (hasRequestedDemand) { + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); + } this.sink.complete(); } @@ -268,6 +282,8 @@ static class BodilessResponseLineSubscriber extends BaseSubscriber { private final ResponseInfo responseInfo; + volatile boolean hasRequestedDemand = false; + public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { this.sink = sink; this.responseInfo = responseInfo; @@ -277,7 +293,10 @@ public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink { - subscription.request(n); + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; }); // Register disposal callback to cancel subscription when Flux is disposed @@ -288,11 +307,13 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnComplete() { - // emit dummy event to be able to inspect the response info - // this is a shortcut allowing for a more streamlined processing using - // operator composition instead of having to deal with the CompletableFuture - // along the Subscriber for inspecting the result - this.sink.next(new DummyEvent(responseInfo)); + if (hasRequestedDemand) { + // emit dummy event to be able to inspect the response info + // this is a shortcut allowing for a more streamlined processing using + // operator composition instead of having to deal with the + // CompletableFuture along the Subscriber for inspecting the result + this.sink.next(new DummyEvent(responseInfo)); + } this.sink.complete(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java similarity index 92% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 009d415e0..1b4eaca97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -15,8 +15,8 @@ import java.util.function.Consumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -48,7 +48,7 @@ public class StdioClientTransport implements McpClientTransport { /** The server process being communicated with */ private Process process; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; /** Scheduler for handling inbound messages from the server process */ private Scheduler inboundScheduler; @@ -70,29 +70,20 @@ public class StdioClientTransport implements McpClientTransport { private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); /** - * Creates a new StdioClientTransport with the specified parameters and default - * ObjectMapper. + * Creates a new StdioClientTransport with the specified parameters and JsonMapper. * @param params The parameters for configuring the server process + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioClientTransport(ServerParameters params) { - this(params, new ObjectMapper()); - } - - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { + public StdioClientTransport(ServerParameters params, McpJsonMapper jsonMapper) { Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.params = params; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); @@ -259,7 +250,7 @@ private void startInboundProcessing() { String line; while (!isClosing && (line = processReader.readLine()) != null) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { if (!isClosing) { logger.error("Failed to enqueue inbound message: {}", message); @@ -300,7 +291,7 @@ private void startOutboundProcessing() { .handle((message, s) -> { if (message != null && !isClosing) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec: // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio // - Messages are delimited by newlines, and MUST NOT contain @@ -392,8 +383,8 @@ public Sinks.Many getErrorSink() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..2492efe18 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import org.reactivestreams.Publisher; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +import reactor.core.publisher.Mono; + +/** + * Composable {@link McpAsyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpAsyncHttpClientRequestCustomizer implements McpAsyncHttpClientRequestCustomizer { + + private final List customizers; + + public DelegatingMcpAsyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.customizers = customizers; + } + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + var result = Mono.just(builder); + for (var customizer : this.customizers) { + result = result.flatMap(b -> Mono.from(customizer.customize(b, method, endpoint, body, context))); + } + return result; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e627e7e69 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +/** + * Composable {@link McpSyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpSyncHttpClientRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private final List delegates; + + public DelegatingMcpSyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.delegates = customizers; + } + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + this.delegates.forEach(delegate -> delegate.customize(builder, method, endpoint, body, context)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java similarity index 62% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java index dee026d96..756b39c35 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java @@ -2,15 +2,18 @@ * Copyright 2024-2025 the original author or authors. */ -package io.modelcontextprotocol.client.transport; +package io.modelcontextprotocol.client.transport.customizer; import java.net.URI; import java.net.http.HttpRequest; + import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.annotation.Nullable; +import io.modelcontextprotocol.common.McpTransportContext; + /** * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or * Streamable HTTP transport. @@ -19,12 +22,12 @@ * * @author Daniel Garnier-Moiroux */ -public interface AsyncHttpRequestCustomizer { +public interface McpAsyncHttpClientRequestCustomizer { Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, - @Nullable String body); + @Nullable String body, McpTransportContext context); - AsyncHttpRequestCustomizer NOOP = new Noop(); + McpAsyncHttpClientRequestCustomizer NOOP = new Noop(); /** * Wrap a sync implementation in an async wrapper. @@ -32,18 +35,18 @@ Publisher customize(HttpRequest.Builder builder, String met * Do NOT wrap a blocking implementation for use in a non-blocking context. For a * blocking implementation, consider using {@link Schedulers#boundedElastic()}. */ - static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) { - return (builder, method, uri, body) -> Mono.fromSupplier(() -> { - customizer.customize(builder, method, uri, body); + static McpAsyncHttpClientRequestCustomizer fromSync(McpSyncHttpClientRequestCustomizer customizer) { + return (builder, method, uri, body, context) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body, context); return builder; }); } - class Noop implements AsyncHttpRequestCustomizer { + class Noop implements McpAsyncHttpClientRequestCustomizer { @Override public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, - String body) { + String body, McpTransportContext context) { return Mono.just(builder); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e22e3aa62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. Do not rely on thread-locals in this implementation, instead + * use {@link SyncSpec#transportContextProvider} to extract context, and then consume it + * through {@link McpTransportContext}. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpSyncHttpClientRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body, + McpTransportContext context); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java new file mode 100644 index 000000000..cde637b15 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation for {@link McpTransportContext} which uses a map as storage. + * + * @author Dariusz Jędrzejczyk + * @author Daniel Garnier-Moiroux + */ +class DefaultMcpTransportContext implements McpTransportContext { + + private final Map metadata; + + DefaultMcpTransportContext(Map metadata) { + Assert.notNull(metadata, "The metadata cannot be null"); + this.metadata = metadata; + } + + @Override + public Object get(String key) { + return this.metadata.get(key); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) + return false; + + DefaultMcpTransportContext that = (DefaultMcpTransportContext) o; + return this.metadata.equals(that.metadata); + } + + @Override + public int hashCode() { + return this.metadata.hashCode(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java similarity index 68% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java rename to mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java index 1cd540f72..46a2ccf84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java @@ -2,9 +2,10 @@ * Copyright 2024-2025 the original author or authors. */ -package io.modelcontextprotocol.server; +package io.modelcontextprotocol.common; import java.util.Collections; +import java.util.Map; /** * Context associated with the transport layer. It allows to add transport-level metadata @@ -26,6 +27,15 @@ public interface McpTransportContext { @SuppressWarnings("unchecked") McpTransportContext EMPTY = new DefaultMcpTransportContext(Collections.EMPTY_MAP); + /** + * Create an unmodifiable context containing the given metadata. + * @param metadata the transport metadata + * @return the context containing the metadata + */ + static McpTransportContext create(Map metadata) { + return new DefaultMcpTransportContext(metadata); + } + /** * Extract a value from the context. * @param key the key under the data is expected @@ -33,18 +43,4 @@ public interface McpTransportContext { */ Object get(String key); - /** - * Inserts a value for a given key. - * @param key a String representing the key - * @param value the value to store - */ - void put(String key, Object value); - - /** - * Copies the contents of the context to allow further modifications without affecting - * the initial object. - * @return a new instance with the underlying storage copied. - */ - McpTransportContext copy(); - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java index 2df3514b6..d1b55f594 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import org.slf4j.Logger; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java similarity index 68% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index a51c2e36c..23285d514 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -5,7 +5,6 @@ package io.modelcontextprotocol.server; import java.time.Duration; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -15,34 +14,37 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; -import io.modelcontextprotocol.spec.McpServerTransportProviderBase; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + /** * The Model Context Protocol (MCP) server implementation that provides asynchronous * communication using Project Reactor's Mono and Flux types. @@ -91,7 +93,7 @@ public class McpAsyncServer { private final McpServerTransportProviderBase mcpTransportProvider; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final JsonSchemaValidator jsonSchemaValidator; @@ -103,10 +105,10 @@ public class McpAsyncServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); // FIXME: this field is deprecated and should be remvoed together with the @@ -117,26 +119,26 @@ public class McpAsyncServer { private List protocolVersions; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); /** * Create a new McpAsyncServer with the given transport provider and capabilities. * @param mcpTransportProvider The transport layer implementation for MCP * communication. * @param features The MCP server supported features. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -151,17 +153,17 @@ public class McpAsyncServer { requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); } - McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -319,25 +321,24 @@ private McpNotificationHandler asyncRootsListChangedNotificationHandler( */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); } if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); + return Mono.error(new IllegalArgumentException("Tool must not be null")); } if (toolSpecification.call() == null && toolSpecification.callHandler() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { - return Mono.error( - new McpError("Tool with name '" + wrappedToolSpecification.tool().name() + "' already exists")); + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); } this.tools.add(wrappedToolSpecification); @@ -376,6 +377,11 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal return this.delegateCallToolResult.apply(exchange, request).map(result -> { + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + if (outputSchema == null) { if (result.structuredContent() != null) { logger.warn( @@ -391,11 +397,12 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal // results that conform to this schema. // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema if (result.structuredContent() == null) { - logger.warn( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - return new CallToolResult( - "Response missing structured content which is expected when calling tool with non-empty outputSchema", - true); + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); } // Validate the result against the output schema @@ -403,7 +410,10 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal if (!validation.valid()) { logger.warn("Tool call result validation failed: {}", validation.errorMessage()); - return new CallToolResult(validation.errorMessage(), true); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); } if (Utils.isEmpty(result.content())) { @@ -413,8 +423,11 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal // TextContent block.) // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content - return new CallToolResult(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput())), - result.isError(), result.structuredContent()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); } return result; @@ -453,6 +466,14 @@ private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHand .build(); } + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpServerFeatures.AsyncToolSpecification::tool); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -460,23 +481,25 @@ private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHand */ public Mono removeTool(String toolName) { if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); + return Mono.error(new IllegalArgumentException("Tool name must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { return notifyToolsListChanged(); } - return Mono.empty(); } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); }); } @@ -498,8 +521,8 @@ private McpRequestHandler toolsListRequestHandler() { private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { }); Optional toolSpecification = this.tools.stream() @@ -507,11 +530,13 @@ private McpRequestHandler toolsCallRequestHandler() { .findAny(); if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); } - return toolSpecification.map(tool -> Mono.defer(() -> tool.callHandler().apply(exchange, callToolRequest))) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); }; } @@ -526,19 +551,22 @@ private McpRequestHandler toolsCallRequestHandler() { */ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + return Mono.error(new IllegalArgumentException("Resource must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resources")); } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); if (this.serverCapabilities.resources().listChanged()) { return notifyResourcesListChanged(); } @@ -546,6 +574,14 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou }); } + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()).map(McpServerFeatures.AsyncResourceSpecification::resource); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -553,10 +589,11 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou */ public Mono removeResource(String resourceUri) { if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resources")); } return Mono.defer(() -> { @@ -568,7 +605,74 @@ public Mono removeResource(String resourceUri) { } return Mono.empty(); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + else { + logger.warn("Ignore as a Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates.remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); }); } @@ -600,46 +704,50 @@ private McpRequestHandler resourcesListRequestHan } private McpRequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; } - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.title(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); + private McpRequestHandler resourcesReadRequestHandler() { + return (ex, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); - list.addAll(resourceTemplates); + var resourceUri = resourceRequest.uri(); - return list; + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + }; } - private McpRequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } - return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); - }; + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); } // --------------------------------------- @@ -653,32 +761,36 @@ private McpRequestHandler resourcesReadRequestHand */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error( - new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); + return this.notifyPromptsListChanged(); } + return Mono.empty(); }); } + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()).map(McpServerFeatures.AsyncPromptSpecification::prompt); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -686,10 +798,10 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe */ public Mono removePrompt(String promptName) { if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { @@ -697,14 +809,15 @@ public Mono removePrompt(String promptName) { if (removed != null) { logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes if (this.serverCapabilities.prompts().listChanged()) { return this.notifyPromptsListChanged(); } return Mono.empty(); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + return Mono.empty(); }); } @@ -734,14 +847,18 @@ private McpRequestHandler promptsListRequestHandler private McpRequestHandler promptsGetRequestHandler() { return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { }); // Implement prompt retrieval logic here McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); } return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); @@ -782,9 +899,8 @@ private McpRequestHandler setLoggerRequestHandler() { return (exchange, params) -> { return Mono.defer(() -> { - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); + SetLevelRequest newMinLoggingLevel = jsonMapper.convertValue(params, new TypeRef() { + }); exchange.setMinLoggingLevel(newMinLoggingLevel.level()); @@ -797,27 +913,38 @@ private McpRequestHandler setLoggerRequestHandler() { }; } + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + private McpRequestHandler completionCompleteRequestHandler() { return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); } if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); } String type = request.ref().type(); String argumentName = request.argument().name(); - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); } if (!promptSpec.prompt() .arguments() @@ -826,27 +953,67 @@ private McpRequestHandler completionCompleteRequestHan .findFirst() .isPresent()) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; } } - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; } + McpServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } } + // Handle the completion request using the registered handler + // for the given reference. McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); } return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); @@ -877,9 +1044,9 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { String refType = (String) refMap.get("type"); McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), refMap.get("title") != null ? (String) refMap.get("title") : null); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); default -> throw new IllegalArgumentException("Invalid ref type: " + refType); }; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 61d60bacc..a15c58cd5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -4,10 +4,11 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Collections; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; @@ -36,16 +37,16 @@ public class McpAsyncServerExchange { private final McpTransportContext transportContext; - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CREATE_MESSAGE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_ROOTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef ELICITATION_RESULT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index f5dfffffb..fe3125271 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; @@ -13,10 +14,9 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; -import io.modelcontextprotocol.spec.DefaultJsonSchemaValidator; -import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; @@ -24,7 +24,7 @@ import io.modelcontextprotocol.spec.McpStatelessServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import reactor.core.publisher.Mono; @@ -66,17 +66,23 @@ * Example of creating a basic synchronous server:
{@code
  * McpServer.sync(transportProvider)
  *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
- *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
+ *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+ *           (exchange, args) -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
+ *                   .isError(false)
+ *                   .build())
  *     .build();
  * }
* * Example of creating a basic asynchronous server:
{@code
  * McpServer.async(transportProvider)
  *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
+ *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
  *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
- *               .map(result -> new CallToolResult("Result: " + result)))
+ *               .map(result -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+ *                   .isError(false)
+ *                   .build()))
  *     .build();
  * }
* @@ -90,12 +96,18 @@ * McpServerFeatures.AsyncToolSpecification.builder() * .tool(calculatorTool) * .callTool((exchange, args) -> Mono.fromSupplier(() -> calculate(args.arguments())) - * .map(result -> new CallToolResult("Result: " + result)))) + * .map(result -> CallToolResult.builder() + * .content(List.of(new McpSchema.TextContent("Result: " + result))) + * .isError(false) + * .build())) *. .build(), * McpServerFeatures.AsyncToolSpecification.builder() * .tool((weatherTool) * .callTool((exchange, args) -> Mono.fromSupplier(() -> getWeather(args.arguments())) - * .map(result -> new CallToolResult("Weather: " + result)))) + * .map(result -> CallToolResult.builder() + * .content(List.of(new McpSchema.TextContent("Weather: " + result))) + * .isError(false) + * .build())) * .build() * ) * // Register resources @@ -133,7 +145,7 @@ */ public interface McpServer { - McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("Java SDK MCP Server", "0.15.0"); /** * Starts building a synchronous MCP server that provides blocking operations. @@ -226,11 +238,12 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); } } @@ -253,11 +266,10 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + : JsonSchemaValidator.getDefault(); + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); } } @@ -267,9 +279,9 @@ public McpAsyncServer build() { */ abstract class AsyncSpecification> { - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -297,7 +309,14 @@ abstract class AsyncSpecification> { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -416,9 +435,12 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCap *

* Example usage:

{@code
 		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
+		 *     Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
 		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
-		 *         .map(result -> new CallToolResult("Result: " + result))
+		 *         .map(result -> CallToolResult.builder()
+		 *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+		 *                   .isError(false)
+		 *                   .build()))
 		 * )
 		 * }
* @param tool The tool definition including name, description, and schema. Must @@ -584,40 +606,38 @@ public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecificat } /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates Map of template URI to specification. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public AsyncSpecification resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates List of template URI to specification. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates( + McpServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpServerFeatures.AsyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); } return this; } @@ -764,14 +784,14 @@ public AsyncSpecification rootsChangeHandlers( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public AsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -812,13 +832,11 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -845,13 +863,11 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + : JsonSchemaValidator.getDefault(); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); - return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -862,9 +878,9 @@ public McpSyncServer build() { */ abstract class SyncSpecification> { - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -890,7 +906,14 @@ abstract class SyncSpecification> { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); JsonSchemaValidator jsonSchemaValidator; @@ -1013,8 +1036,11 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapa *

* Example usage:

{@code
 		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
+		 *     Tool.builder().name("calculator").title("Performs calculations".inputSchema(schema).build(),
+		 *     (exchange, args) -> CallToolResult.builder()
+		 *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
+		 *                   .isError(false)
+		 *                   .build())
 		 * )
 		 * }
* @param tool The tool definition including name, description, and schema. Must @@ -1182,23 +1208,17 @@ public SyncSpecification resources(McpServerFeatures.SyncResourceSpecificatio /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * @param resourceTemplates List of resource template specifications. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public SyncSpecification resourceTemplates(List resourceTemplates) { + public SyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (McpServerFeatures.SyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } return this; } @@ -1210,10 +1230,11 @@ public SyncSpecification resourceTemplates(List resourceTem * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public SyncSpecification resourceTemplates( + McpServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -1362,14 +1383,14 @@ public SyncSpecification rootsChangeHandlers( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public SyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -1401,9 +1422,9 @@ class StatelessAsyncSpecification { private final McpStatelessServerTransport transport; - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -1431,7 +1452,14 @@ class StatelessAsyncSpecification { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -1687,23 +1715,17 @@ public StatelessAsyncSpecification resources( /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
* @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public StatelessAsyncSpecification resourceTemplates(List resourceTemplates) { + public StatelessAsyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } @@ -1715,10 +1737,11 @@ public StatelessAsyncSpecification resourceTemplates(List reso * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public StatelessAsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public StatelessAsyncSpecification resourceTemplates( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -1820,14 +1843,14 @@ public StatelessAsyncSpecification completions( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public StatelessAsyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public StatelessAsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -1848,11 +1871,9 @@ public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonS public McpStatelessAsyncServer build() { var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpStatelessAsyncServer(this.transport, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); } } @@ -1863,9 +1884,9 @@ class StatelessSyncSpecification { boolean immediateExecution = false; - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -1893,7 +1914,14 @@ class StatelessSyncSpecification { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -2149,23 +2177,17 @@ public StatelessSyncSpecification resources( /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * @param resourceTemplatesSpec List of resource templates. If null, clears + * existing templates. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public StatelessSyncSpecification resourceTemplates(List resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + public StatelessSyncSpecification resourceTemplates( + List resourceTemplatesSpec) { + Assert.notNull(resourceTemplatesSpec, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplatesSpec) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } @@ -2177,10 +2199,11 @@ public StatelessSyncSpecification resourceTemplates(List resou * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public StatelessSyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public StatelessSyncSpecification resourceTemplates( + McpStatelessServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -2282,14 +2305,14 @@ public StatelessSyncSpecification completions( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public StatelessSyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public StatelessSyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -2324,31 +2347,13 @@ public StatelessSyncSpecification immediateExecution(boolean immediateExecution) } public McpStatelessSyncServer build() { - /* - * McpServerFeatures.Sync syncFeatures = new - * McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - * this.tools, this.resources, this.resourceTemplates, this.prompts, - * this.completions, this.rootsChangeHandlers, this.instructions); - * McpServerFeatures.Async asyncFeatures = - * McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - * var mapper = this.objectMapper != null ? this.objectMapper : new - * ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null - * ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - * - * var asyncServer = new McpAsyncServer(this.transportProvider, mapper, - * asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, - * jsonSchemaValidator); - * - * return new McpSyncServer(asyncServer, this.immediateExecution); - */ var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + var asyncServer = new McpStatelessAsyncServer(transport, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault()); return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java similarity index 83% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 12edfb341..fe0608b1c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -41,7 +41,7 @@ public class McpServerFeatures { */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -53,7 +53,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes @@ -61,7 +61,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -84,7 +84,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.tools = (tools != null) ? tools : List.of(); this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); @@ -112,6 +112,11 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); }); + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); @@ -130,8 +135,8 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { .subscribeOn(Schedulers.boundedElastic())); } - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, rootChangeConsumers, syncSpec.instructions()); } } @@ -151,7 +156,7 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, String instructions) { @@ -171,7 +176,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, @@ -194,7 +199,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.tools = (tools != null) ? tools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); @@ -329,7 +334,13 @@ public static Builder builder() { * *
{@code
 	 * new McpServerFeatures.AsyncResourceSpecification(
-	 * 		new Resource("docs", "Documentation files", "text/markdown"),
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
 	 * 		(exchange, request) -> Mono.fromSupplier(() -> readFile(request.getPath()))
 	 * 				.map(ReadResourceResult::new))
 	 * }
@@ -356,6 +367,47 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, b } } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + /** * Specification of a prompt template with its asynchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: @@ -453,15 +505,19 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * *
{@code
 	 * McpServerFeatures.SyncToolSpecification.builder()
-	 * 		.tool(new Tool(
-	 * 				"calculator",
-	 * 				"Performs mathematical calculations",
-	 * 				new JsonSchemaObject()
+	 * 		.tool(Tool.builder()
+	 * 				.name("calculator")
+	 * 				.title("Performs mathematical calculations")
+	 * 				.inputSchema(new JsonSchemaObject()
 	 * 						.required("expression")
-	 * 						.property("expression", JsonSchemaType.STRING)))
+	 * 						.property("expression", JsonSchemaType.STRING))
+	 * 				.build()
 	 * 		.toolHandler((exchange, req) -> {
 	 * 			String expr = (String) req.arguments().get("expression");
-	 * 			return new CallToolResult("Result: " + evaluate(expr));
+	 * 			return CallToolResult.builder()
+	 *                   .content(List.of(new McpSchema.TextContent("Result: " + evaluate(expr))))
+	 *                   .isError(false)
+	 *                   .build();
 	 * 		}))
 	 *      .build();
 	 * }
@@ -557,7 +613,13 @@ public static Builder builder() { * *
{@code
 	 * new McpServerFeatures.SyncResourceSpecification(
-	 * 		new Resource("docs", "Documentation files", "text/markdown"),
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
 	 * 		(exchange, request) -> {
 	 * 			String content = readFile(request.getPath());
 	 * 			return new ReadResourceResult(content);
@@ -574,6 +636,34 @@ public record SyncResourceSpecification(McpSchema.Resource resource,
 			BiFunction readHandler) {
 	}
 
+	/**
+	 * Specification of a resource template with its synchronous handler function.
+	 * Resource templates allow servers to expose parameterized resources using URI
+	 * templates:  URI
+	 * templates.. Arguments may be auto-completed through the
+	 * completion API.
+	 *
+	 * Templates support:
+	 * 
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + /** * Specification of a prompt template with its synchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java similarity index 60% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index 41e0e9588..c7a1fd0d7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -4,21 +4,27 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceTemplateSpecification; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpStatelessServerTransport; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import java.time.Duration; @@ -31,6 +37,8 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + /** * A stateless MCP server implementation for use with Streamable HTTP transport types. It * allows simple horizontal scalability since it does not maintain a session and does not @@ -45,7 +53,7 @@ public class McpStatelessAsyncServer { private final McpStatelessServerTransport mcpTransportProvider; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final McpSchema.ServerCapabilities serverCapabilities; @@ -55,7 +63,7 @@ public class McpStatelessAsyncServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); @@ -65,21 +73,21 @@ public class McpStatelessAsyncServer { private List protocolVersions; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); private final JsonSchemaValidator jsonSchemaValidator; - McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, McpStatelessServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransport; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -129,7 +137,7 @@ public class McpStatelessAsyncServer { // --------------------------------------- private McpStatelessRequestHandler asyncInitializeRequestHandler() { return (ctx, req) -> Mono.defer(() -> { - McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, + McpSchema.InitializeRequest initializeRequest = this.jsonMapper.convertValue(req, McpSchema.InitializeRequest.class); logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", @@ -248,6 +256,11 @@ public Mono apply(McpTransportContext transportContext, McpSchem return this.delegateHandler.apply(transportContext, request).map(result -> { + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + if (outputSchema == null) { if (result.structuredContent() != null) { logger.warn( @@ -263,11 +276,12 @@ public Mono apply(McpTransportContext transportContext, McpSchem // results that conform to this schema. // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema if (result.structuredContent() == null) { - logger.warn( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - return new CallToolResult( - "Response missing structured content which is expected when calling tool with non-empty outputSchema", - true); + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); } // Validate the result against the output schema @@ -275,7 +289,10 @@ public Mono apply(McpTransportContext transportContext, McpSchem if (!validation.valid()) { logger.warn("Tool call result validation failed: {}", validation.errorMessage()); - return new CallToolResult(validation.errorMessage(), true); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); } if (Utils.isEmpty(result.content())) { @@ -285,8 +302,11 @@ public Mono apply(McpTransportContext transportContext, McpSchem // TextContent block.) // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content - return new CallToolResult(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput())), - result.isError(), result.structuredContent()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); } return result; @@ -302,25 +322,24 @@ public Mono apply(McpTransportContext transportContext, McpSchem */ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); } if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); + return Mono.error(new IllegalArgumentException("Tool must not be null")); } if (toolSpecification.callHandler() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { - return Mono.error( - new McpError("Tool with name '" + wrappedToolSpecification.tool().name() + "' already exists")); + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); } this.tools.add(wrappedToolSpecification); @@ -330,6 +349,14 @@ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification tool }); } + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpStatelessServerFeatures.AsyncToolSpecification::tool); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -337,20 +364,22 @@ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification tool */ public Mono removeTool(String toolName) { if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); + return Mono.error(new IllegalArgumentException("Tool name must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + logger.debug("Removed tool handler: {}", toolName); - return Mono.empty(); } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); }); } @@ -365,8 +394,8 @@ private McpStatelessRequestHandler toolsListRequestHa private McpStatelessRequestHandler toolsCallRequestHandler() { return (ctx, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { }); Optional toolSpecification = this.tools.stream() @@ -374,11 +403,13 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { .findAny(); if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); } - return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest)) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); }; } @@ -393,23 +424,34 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { */ public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + return Mono.error(new IllegalArgumentException("Resource must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); return Mono.empty(); }); } + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()) + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -417,19 +459,83 @@ public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecificat */ public Mono removeResource(String resourceUri) { if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); } return Mono.defer(() -> { McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); if (removed != null) { logger.debug("Removed resource handler: {}", resourceUri); - return Mono.empty(); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + else { + logger.warn("Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpStatelessServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates + .remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); }); } @@ -444,47 +550,52 @@ private McpStatelessRequestHandler resourcesListR } private McpStatelessRequestHandler resourceTemplateListRequestHandler() { - return (ctx, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new ResourceTemplate(resource.uri(), resource.name(), resource.title(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; } private McpStatelessRequestHandler resourcesReadRequestHandler() { return (ctx, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); var resourceUri = resourceRequest.uri(); - McpStatelessServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); - return specification.readHandler().apply(ctx, resourceRequest); }; } + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + // --------------------------------------- // Prompt Management // --------------------------------------- @@ -496,26 +607,34 @@ private McpStatelessRequestHandler resourcesReadRe */ public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { - McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error( - new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); return Mono.empty(); }); } + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()) + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -523,10 +642,10 @@ public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification */ public Mono removePrompt(String promptName) { if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { @@ -536,7 +655,11 @@ public Mono removePrompt(String promptName) { logger.debug("Removed prompt handler: {}", promptName); return Mono.empty(); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + + return Mono.empty(); }); } @@ -558,67 +681,122 @@ private McpStatelessRequestHandler promptsListReque private McpStatelessRequestHandler promptsGetRequestHandler() { return (ctx, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { }); // Implement prompt retrieval logic here McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); } return specification.promptHandler().apply(ctx, promptRequest); }; } + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + private McpStatelessRequestHandler completionCompleteRequestHandler() { return (ctx, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); } if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); } String type = request.ref().type(); String argumentName = request.argument().name(); - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts .get(promptReference.name()); if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); } - if (promptSpec.prompt().arguments().stream().noneMatch(arg -> arg.name().equals(argumentName))) { + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; } } - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources - .get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; } + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } } McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); } return specification.completionHandler().apply(ctx, request); @@ -647,9 +825,9 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { String refType = (String) refMap.get("type"); McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), refMap.get("title") != null ? (String) refMap.get("title") : null); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); default -> throw new IllegalArgumentException("Invalid ref type: " + refType); }; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java index 6db79a62c..a2fabb283 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java index e5c9e7c09..37cd3c096 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java similarity index 80% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java index 8be59a779..a15681ba5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -4,6 +4,13 @@ package io.modelcontextprotocol.server; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.util.Assert; @@ -11,12 +18,6 @@ import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiFunction; - /** * MCP stateless server features specification that a particular server can choose to * support. @@ -33,13 +34,14 @@ public class McpStatelessServerFeatures { * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, - Map resources, List resourceTemplates, + Map resources, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -50,13 +52,14 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, - Map resources, List resourceTemplates, + Map resources, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -67,10 +70,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities : new McpSchema.ServerCapabilities(null, // completions null, // experimental - new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable - // logging - // by - // default + null, // currently statless server doesn't support set logging !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, @@ -78,7 +78,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.tools = (tools != null) ? tools : List.of(); this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.instructions = instructions; @@ -105,6 +105,11 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); }); + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); @@ -115,8 +120,8 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); }); - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, syncSpec.instructions()); } } @@ -127,14 +132,14 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -145,14 +150,14 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -174,7 +179,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.tools = (tools != null) ? tools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.instructions = instructions; @@ -298,6 +303,46 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, b } } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + /** * Specification of a prompt template with its asynchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: @@ -448,6 +493,34 @@ public record SyncResourceSpecification(McpSchema.Resource resource, BiFunction readHandler) { } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + /** * Specification of a prompt template with its synchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java similarity index 94% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java index 7c4e23cfc..cbae58bfd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import reactor.core.publisher.Mono; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java similarity index 71% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java index 0151a754b..6849eb8ed 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -74,6 +74,14 @@ public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecifi .block(); } + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -93,6 +101,14 @@ public void addResource(McpStatelessServerFeatures.SyncResourceSpecification res .block(); } + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -101,6 +117,34 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate( + McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpStatelessServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add @@ -112,6 +156,14 @@ public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptS .block(); } + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java similarity index 78% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 5adda1a74..10f0e5a31 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import java.util.List; + import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -87,6 +89,14 @@ public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { .block(); } + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + /** * Remove a tool handler. * @param toolName The name of the tool handler to remove @@ -97,15 +107,23 @@ public void removeTool(String toolName) { /** * Add a new resource handler. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource specification to add */ - public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { + public void addResource(McpServerFeatures.SyncResourceSpecification resourceSpecification) { this.asyncServer - .addResource( - McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler, this.immediateExecution)) + .addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) .block(); } + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + /** * Remove a resource handler. * @param resourceUri The URI of the resource handler to remove @@ -114,6 +132,33 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate(McpServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler. * @param promptSpecification The prompt specification to add @@ -125,6 +170,14 @@ public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecificat .block(); } + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + /** * Remove a prompt handler. * @param promptName The name of the prompt handler to remove diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java similarity index 98% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 5f22df5e9..0b9115b79 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java similarity index 59% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java index 97fcecf0d..ea9f05a4f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; + /** * The contract for extracting metadata from a generic transport request of type * {@link T}. @@ -15,14 +17,11 @@ public interface McpTransportContextExtractor { /** - * Given an empty context, provides the means to fill it with transport-specific - * metadata extracted from the request. + * Extract transport-specific metadata from the request into an McpTransportContext. * @param request the generic representation for the request in the context of a * specific transport implementation - * @param transportContext the mutable context which can be filled in with metadata - * @return the context filled in with metadata. It can be the same instance as - * provided or a new one. + * @return the context containing the metadata */ - McpTransportContext extract(T request, McpTransportContext transportContext); + McpTransportContext extract(T request); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java similarity index 79% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index ceeea31b1..96cebb74a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -14,8 +14,10 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -67,7 +69,9 @@ @WebServlet(asyncSupported = true) public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { - /** Logger for this class */ + /** + * Logger for this class + */ private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); public static final String UTF_8 = "UTF-8"; @@ -76,98 +80,96 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - /** Default endpoint path for SSE connections */ + /** + * Default endpoint path for SSE connections + */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - /** Event type for regular messages */ + /** + * Event type for regular messages + */ public static final String MESSAGE_EVENT_TYPE = "message"; - /** Event type for endpoint information */ + /** + * Event type for endpoint information + */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String SESSION_ID = "sessionId"; + public static final String DEFAULT_BASE_URL = ""; - /** JSON object mapper for serialization/deserialization */ - private final ObjectMapper objectMapper; + /** + * JSON mapper for serialization/deserialization + */ + private final McpJsonMapper jsonMapper; - /** Base URL for the server transport */ + /** + * Base URL for the server transport + */ private final String baseUrl; - /** The endpoint path for handling client messages */ + /** + * The endpoint path for handling client messages + */ private final String messageEndpoint; - /** The endpoint path for handling SSE connections */ + /** + * The endpoint path for handling SSE connections + */ private final String sseEndpoint; - /** Map of active client sessions, keyed by session ID */ + /** + * Map of active client sessions, keyed by session ID + */ private final Map sessions = new ConcurrentHashMap<>(); - /** Flag indicating if the transport is in the process of shutting down */ - private final AtomicBoolean isClosing = new AtomicBoolean(false); - - /** Session factory for creating new sessions */ - private McpServerSession.Factory sessionFactory; + private McpTransportContextExtractor contextExtractor; /** - * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is - * set. Disabled by default. + * Flag indicating if the transport is in the process of shutting down */ - private KeepAliveScheduler keepAliveScheduler; + private final AtomicBoolean isClosing = new AtomicBoolean(false); /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. + * Session factory for creating new sessions */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } + private McpServerSession.Factory sessionFactory; /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param baseUrl The base URL for the server transport - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); - } + private KeepAliveScheduler keepAliveScheduler; /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. - * @param objectMapper The JSON object mapper to use for message + * @param jsonMapper The JSON object mapper to use for message * serialization/deserialization * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections * @param keepAliveInterval The interval for keep-alive pings, or null to disable * keep-alive functionality + * @param contextExtractor The extractor for transport context from the request. * @deprecated Use the builder {@link #builder()} instead for better configuration * options. */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { + private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(messageEndpoint, "messageEndpoint must not be null"); + Assert.notNull(sseEndpoint, "sseEndpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; if (keepAliveInterval != null) { @@ -186,17 +188,6 @@ public List protocolVersions() { return List.of(ProtocolVersions.MCP_2024_11_05); } - /** - * Creates a new HttpServletSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - /** * Sets the session factory for creating new sessions. * @param sessionFactory The session factory to use @@ -276,7 +267,22 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, buildEndpointUrl(sessionId)); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + if (this.baseUrl.endsWith("/")) { + return this.baseUrl.substring(0, this.baseUrl.length() - 1) + this.messageEndpoint + "?sessionId=" + + sessionId; + } + return this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId; } /** @@ -311,7 +317,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -324,7 +330,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_NOT_FOUND); - String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -339,10 +345,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + final McpTransportContext transportContext = this.contextExtractor.extract(request); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); response.setStatus(HttpServletResponse.SC_OK); } @@ -353,7 +361,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -449,7 +457,7 @@ private class HttpServletMcpSessionTransport implements McpServerTransport { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { try { - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); logger.debug("Message sent to session {}", sessionId); } @@ -462,15 +470,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured JsonMapper. * @param data The source data object to convert * @param typeRef The target type reference - * @return The converted object of type T * @param The target type + * @return The converted object of type T */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -526,7 +534,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; private String baseUrl = DEFAULT_BASE_URL; @@ -534,16 +542,21 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + private Duration keepAliveInterval; /** - * Sets the JSON object mapper to use for message serialization/deserialization. - * @param objectMapper The object mapper to use + * Sets the JsonMapper implementation to use for serialization/deserialization. If + * not specified, a JacksonJsonMapper will be created from the configured + * ObjectMapper. + * @param jsonMapper The JsonMapper to use * @return This builder instance for method chaining */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -583,6 +596,19 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Sets the interval for keep-alive pings. *

@@ -599,17 +625,15 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. * @return A new HttpServletSseServerTransportProvider instance - * @throws IllegalStateException if objectMapper or messageEndpoint is not set + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set */ public HttpServletSseServerTransportProvider build() { - if (objectMapper == null) { - throw new IllegalStateException("ObjectMapper must be set"); - } if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + return new HttpServletSseServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval, contextExtractor); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java similarity index 88% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 25b003564..40767f416 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -11,11 +11,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -49,7 +48,7 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String mcpEndpoint; @@ -59,13 +58,13 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements private volatile boolean isClosing = false; - private HttpServletStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; } @@ -123,7 +122,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); String accept = request.getHeader(ACCEPT); if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { @@ -140,7 +139,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { try { @@ -153,7 +152,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_OK); - String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); + String jsonResponseText = jsonMapper.writeValueAsString(jsonrpcResponse); PrintWriter writer = response.getWriter(); writer.write(jsonResponseText); writer.flush(); @@ -204,7 +203,7 @@ private void responseError(HttpServletResponse response, int httpCode, McpError response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -237,26 +236,27 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Builder() { // used by a static method } /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The JsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -295,10 +295,9 @@ public Builder contextExtractor(McpTransportContextExtractor * @throws IllegalStateException if required parameters are not set */ public HttpServletStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new HttpServletStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 8b95ec607..34671c105 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -16,11 +16,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; @@ -30,6 +28,7 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; @@ -98,7 +97,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private final boolean disallowDelete; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private McpStreamableServerSession.Factory sessionFactory; @@ -122,22 +121,22 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet /** * Constructs a new HttpServletStreamableServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of + * messages. * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. * @param disallowDelete Whether to disallow DELETE requests on the endpoint. * @param contextExtractor The extractor for transport context from the request. * @throws IllegalArgumentException if any parameter is null */ - private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; @@ -157,7 +156,8 @@ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); } @Override @@ -274,7 +274,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) logger.debug("Handling GET request for session: {}", sessionId); - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); try { response.setContentType(TEXT_EVENT_STREAM); @@ -383,7 +383,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) badRequestErrors.add("application/json required in Accept header"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); try { BufferedReader reader = request.getReader(); @@ -393,7 +393,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); // Handle initialization request if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest @@ -404,8 +404,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); @@ -419,7 +419,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId()); response.setStatus(HttpServletResponse.SC_OK); - String jsonResponse = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse( + String jsonResponse = jsonMapper.writeValueAsString(new McpSchema.JSONRPCResponse( McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); PrintWriter writer = response.getWriter(); @@ -541,7 +541,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) { this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, @@ -579,7 +579,7 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -686,7 +686,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId return; } - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, messageId != null ? messageId : this.sessionId); logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); @@ -703,15 +703,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured JsonMapper. * @param data The source data object to convert * @param typeRef The target type reference * @return The converted object of type T * @param The target type */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -763,26 +763,27 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; private boolean disallowDelete = false; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Duration keepAliveInterval; /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The JsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if JsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -839,11 +840,10 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * @throws IllegalStateException if required parameters are not set */ public HttpServletStreamableServerTransportProvider build() { - Assert.notNull(this.objectMapper, "ObjectMapper must be set"); Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); - - return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, - this.disallowDelete, this.contextExtractor, this.keepAliveInterval); + return new HttpServletStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index af602f610..68be62931 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -15,8 +15,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -25,6 +24,7 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -44,7 +44,7 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final InputStream inputStream; @@ -56,36 +56,28 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final Sinks.One inboundReady = Sinks.one(); - /** - * Creates a new StdioServerTransportProvider with a default ObjectMapper and System - * streams. - */ - public StdioServerTransportProvider() { - this(new ObjectMapper()); - } - /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * System streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioServerTransportProvider(ObjectMapper objectMapper) { - this(objectMapper, System.in, System.out); + public StdioServerTransportProvider(McpJsonMapper jsonMapper) { + this(jsonMapper, System.in, System.out); } /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization * @param inputStream The input stream to read from * @param outputStream The output stream to write to */ - public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + public StdioServerTransportProvider(McpJsonMapper jsonMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); Assert.notNull(inputStream, "The InputStream can not be null"); Assert.notNull(outputStream, "The OutputStream can not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.inputStream = inputStream; this.outputStream = outputStream; } @@ -165,8 +157,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -219,7 +211,7 @@ private void startInboundProcessing() { logger.debug("Received JSON message: {}", line); try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { // logIfNotClosing("Failed to enqueue message"); @@ -263,7 +255,7 @@ private void startOutboundProcessing() { .handle((message, sink) -> { if (message != null && !isClosing.get()) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java new file mode 100644 index 000000000..b18364abb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package io.modelcontextprotocol.spec; + +import java.util.Optional; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Represents a closed MCP session, which may not be reused. All calls will throw a + * {@link McpTransportSessionClosedException}. + * + * @param the resource representing the connection that the transport + * manages. + * @author Daniel Garnier-Moiroux + */ +public class ClosedMcpTransportSession implements McpTransportSession { + + private final String sessionId; + + public ClosedMcpTransportSession(@Nullable String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Optional sessionId() { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public boolean markInitialized(String sessionId) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void addConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void removeConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void close() { + + } + + @Override + public Publisher closeGracefully() { + return Mono.empty(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java new file mode 100644 index 000000000..6afc2c119 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Names of HTTP headers in use by MCP HTTP transports. + * + * @author Dariusz Jędrzejczyk + */ +public interface HttpHeaders { + + /** + * Identifies individual MCP sessions. + */ + String MCP_SESSION_ID = "Mcp-Session-Id"; + + /** + * Identifies events within an SSE Stream. + */ + String LAST_EVENT_ID = "Last-Event-ID"; + + /** + * Identifies the MCP protocol version. + */ + String PROTOCOL_VERSION = "MCP-Protocol-Version"; + + /** + * The HTTP Content-Length header. + * @see RFC9110 + */ + String CONTENT_LENGTH = "Content-Length"; + + /** + * The HTTP Content-Type header. + * @see RFC9110 + */ + String CONTENT_TYPE = "Content-Type"; + + /** + * The HTTP Accept header. + * @see RFC9110 + */ + String ACCEPT = "Accept"; + + /** + * The HTTP Cache-Control header. + * @see RFC9111 + */ + String CACHE_CONTROL = "Cache-Control"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java index 572d7c043..4a42c9ff3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java @@ -40,6 +40,6 @@ public static ValidationResponse asInvalid(String message) { * @return A ValidationResponse indicating whether the validation was successful or * not. */ - ValidationResponse validate(Map schema, Map structuredContent); + ValidationResponse validate(Map schema, Object structuredContent); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f7db3d7aa..0ba7ab3b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -35,6 +35,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Yanming Zhou */ public class McpClientSession implements McpSession { @@ -146,21 +147,34 @@ private void dismissPendingResponses() { private void handle(McpSchema.JSONRPCMessage message) { if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); handleIncomingRequest(request).onErrorResume(error -> { + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); return Mono.just(errorResponse); }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { logger.warn("Issue sending response to the client, ", t); @@ -246,7 +260,7 @@ private String generateRequestId() { * @return A Mono containing the response */ @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java new file mode 100644 index 000000000..d6e549fdc --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -0,0 +1,116 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import io.modelcontextprotocol.util.Assert; + +import java.util.Map; +import java.util.function.Function; + +public class McpError extends RuntimeException { + + /** + * Resource + * Error Handling + */ + public static final Function RESOURCE_NOT_FOUND = resourceUri -> new McpError(new JSONRPCError( + McpSchema.ErrorCodes.RESOURCE_NOT_FOUND, "Resource not found", Map.of("uri", resourceUri))); + + private JSONRPCError jsonRpcError; + + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.message()); + this.jsonRpcError = jsonRpcError; + } + + @Deprecated + public McpError(Object error) { + super(error.toString()); + } + + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + + @Override + public String toString() { + var builder = new StringBuilder(super.toString()); + if (jsonRpcError != null) { + builder.append("\n"); + builder.append(jsonRpcError.toString()); + } + return builder.toString(); + } + + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + + public static Throwable findRootCause(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + Throwable rootCause = throwable; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + return rootCause; + } + + public static String aggregateExceptionMessages(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + + StringBuilder messages = new StringBuilder(); + Throwable current = throwable; + + while (current != null) { + if (messages.length() > 0) { + messages.append("\n Caused by: "); + } + + messages.append(current.getClass().getSimpleName()); + if (current.getMessage() != null) { + messages.append(": ").append(current.getMessage()); + } + + if (current.getCause() == current) { + break; + } + current = current.getCause(); + } + + return messages.toString(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java similarity index 90% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8a109a8d1..b58f1c552 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -11,20 +11,17 @@ import java.util.List; import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.annotation.JsonTypeInfo.As; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Based on the JSON-RPC 2.0 @@ -45,7 +42,7 @@ private McpSchema() { } @Deprecated - public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; + public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_06_18; public static final String JSONRPC_VERSION = "2.0"; @@ -111,8 +108,6 @@ private McpSchema() { // Elicitation Methods public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - // --------------------------- // JSON-RPC Error Codes // --------------------------- @@ -146,44 +141,58 @@ public static final class ErrorCodes { */ public static final int INTERNAL_ERROR = -32603; + /** + * Resource not found. + */ + public static final int RESOURCE_NOT_FOUND = -32002; + } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, - GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + /** + * Base interface for MCP objects that include optional metadata in the `_meta` field. + */ + public interface Meta { + /** + * @see Specification + * for notes on _meta usage + * @return additional metadata related to this resource. + */ Map meta(); - default String progressToken() { + } + + public sealed interface Request extends Meta + permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, + GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + + default Object progressToken() { if (meta() != null && meta().containsKey("progressToken")) { - return meta().get("progressToken").toString(); + return meta().get("progressToken"); } return null; } } - public sealed interface Result permits InitializeResult, ListResourcesResult, ListResourceTemplatesResult, - ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, CallToolResult, - CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { - - Map meta(); + public sealed interface Result extends Meta permits InitializeResult, ListResourcesResult, + ListResourceTemplatesResult, ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, + CallToolResult, CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { } - public sealed interface Notification + public sealed interface Notification extends Meta permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { - Map meta(); - } - private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + private static final TypeRef> MAP_TYPE_REF = new TypeRef<>() { }; /** * Deserializes a JSON string into a JSONRPCMessage object. - * @param objectMapper The ObjectMapper instance to use for deserialization + * @param jsonMapper The JsonMapper instance to use for deserialization * @param jsonText The JSON string to deserialize * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. @@ -191,22 +200,22 @@ public sealed interface Notification * @throws IllegalArgumentException If the JSON structure doesn't match any known * message type */ - public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) + public static JSONRPCMessage deserializeJsonRpcMessage(McpJsonMapper jsonMapper, String jsonText) throws IOException { logger.debug("Received JSON message: {}", jsonText); - var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); + var map = jsonMapper.readValue(jsonText, MAP_TYPE_REF); // Determine message type based on specific JSON structure if (map.containsKey("method") && map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCRequest.class); + return jsonMapper.convertValue(map, JSONRPCRequest.class); } else if (map.containsKey("method") && !map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCNotification.class); + return jsonMapper.convertValue(map, JSONRPCNotification.class); } else if (map.containsKey("result") || map.containsKey("error")) { - return objectMapper.convertValue(map, JSONRPCResponse.class); + return jsonMapper.convertValue(map, JSONRPCResponse.class); } throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); @@ -268,12 +277,12 @@ public record JSONRPCNotification( // @formatter:off } /** - * A successful (non-error) response to a request. + * A response to a request (successful, or error). * * @param jsonrpc The JSON-RPC version (must be "2.0") * @param id The request identifier that this response corresponds to - * @param result The result of the successful request - * @param error Error information if the request failed + * @param result The result of the successful request; null if error + * @param error Error information if the request failed; null if has result */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) @@ -297,7 +306,7 @@ public record JSONRPCResponse( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCError( // @formatter:off - @JsonProperty("code") int code, + @JsonProperty("code") Integer code, @JsonProperty("message") String message, @JsonProperty("data") Object data) { // @formatter:on } @@ -407,9 +416,47 @@ public record Sampling() { * maintain control over user interactions and data sharing while enabling servers * to gather necessary information dynamically. Servers can request structured * data from users with optional JSON schemas to validate responses. + * + *

+ * Per the 2025-11-25 spec, clients can declare support for specific elicitation + * modes: + *

    + *
  • {@code form} - In-band structured data collection with optional schema + * validation
  • + *
  • {@code url} - Out-of-band interaction via URL navigation
  • + *
+ * + *

+ * For backward compatibility, an empty elicitation object {@code {}} is + * equivalent to declaring support for form mode only. + * + * @param form support for in-band form-based elicitation + * @param url support for out-of-band URL-based elicitation */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record Elicitation() { + public record Elicitation(@JsonProperty("form") Form form, @JsonProperty("url") Url url) { + + /** + * Marker record indicating support for form-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Form() { + } + + /** + * Marker record indicating support for URL-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Url() { + } + + /** + * Creates an Elicitation with default settings (backward compatible, produces + * empty JSON object). + */ + public Elicitation() { + this(null, null); + } } public static Builder builder() { @@ -441,11 +488,28 @@ public Builder sampling() { return this; } + /** + * Enables elicitation capability with default settings (backward compatible, + * produces empty JSON object). + * @return this builder + */ public Builder elicitation() { this.elicitation = new Elicitation(); return this; } + /** + * Enables elicitation capability with explicit form and/or url mode support. + * @param form whether to support form-based elicitation + * @param url whether to support URL-based elicitation + * @return this builder + */ + public Builder elicitation(boolean form, boolean url) { + this.elicitation = new Elicitation(form ? new Elicitation.Form() : null, + url ? new Elicitation.Url() : null); + return this; + } + public ClientCapabilities build() { return new ClientCapabilities(experimental, roots, sampling, elicitation); } @@ -607,7 +671,7 @@ public ServerCapabilities build() { public record Implementation( // @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, - @JsonProperty("version") String version) implements BaseMetadata { // @formatter:on + @JsonProperty("version") String version) implements Identifier { // @formatter:on public Implementation(String name, String version) { this(name, null, version); @@ -651,7 +715,13 @@ public interface Annotated { @JsonIgnoreProperties(ignoreUnknown = true) public record Annotations( // @formatter:off @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority) { // @formatter:on + @JsonProperty("priority") Double priority, + @JsonProperty("lastModified") String lastModified + ) { // @formatter:on + + public Annotations(List audience, Double priority) { + this(audience, priority, null); + } } /** @@ -660,7 +730,9 @@ public record Annotations( // @formatter:off * interface is implemented by both {@link Resource} and {@link ResourceLink} to * provide a consistent way to access resource metadata. */ - public interface ResourceContent extends BaseMetadata { + public interface ResourceContent extends Identifier, Annotated, Meta { + + // name & title from Identifier String uri(); @@ -670,15 +742,15 @@ public interface ResourceContent extends BaseMetadata { Long size(); - Annotations annotations(); + // annotations from Annotated + // meta from Meta } /** - * Base interface for metadata with name (identifier) and title (display name) - * properties. + * Base interface with name (identifier) and title (display name) properties. */ - public interface BaseMetadata { + public interface Identifier { /** * Intended for programmatic or logical use, but used as a display name in past @@ -724,7 +796,7 @@ public record Resource( // @formatter:off @JsonProperty("mimeType") String mimeType, @JsonProperty("size") Long size, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, ResourceContent { // @formatter:on + @JsonProperty("_meta") Map meta) implements ResourceContent { // @formatter:on /** * @deprecated Only exists for backwards-compatibility purposes. Use @@ -854,7 +926,7 @@ public record ResourceTemplate( // @formatter:off @JsonProperty("description") String description, @JsonProperty("mimeType") String mimeType, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, BaseMetadata { // @formatter:on + @JsonProperty("_meta") Map meta) implements Annotated, Identifier, Meta { // @formatter:on public ResourceTemplate(String uriTemplate, String name, String title, String description, String mimeType, Annotations annotations) { @@ -865,6 +937,70 @@ public ResourceTemplate(String uriTemplate, String name, String description, Str Annotations annotations) { this(uriTemplate, name, null, description, mimeType, annotations); } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uriTemplate; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Map meta; + + public Builder uriTemplate(String uri) { + this.uriTemplate = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceTemplate build() { + Assert.hasText(uriTemplate, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceTemplate(uriTemplate, name, title, description, mimeType, annotations, meta); + } + + } } /** @@ -982,10 +1118,10 @@ public UnsubscribeRequest(String uri) { /** * The contents of a specific resource or sub-resource. */ - @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, include = As.PROPERTY) - @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class, name = "text"), - @JsonSubTypes.Type(value = BlobResourceContents.class, name = "blob") }) - public sealed interface ResourceContents permits TextResourceContents, BlobResourceContents { + @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION) + @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class), + @JsonSubTypes.Type(value = BlobResourceContents.class) }) + public sealed interface ResourceContents extends Meta permits TextResourceContents, BlobResourceContents { /** * The URI of this resource. @@ -999,14 +1135,6 @@ public sealed interface ResourceContents permits TextResourceContents, BlobResou */ String mimeType(); - /** - * @see Specification - * for notes on _meta usage - * @return additional metadata related to this resource. - */ - Map meta(); - } /** @@ -1073,7 +1201,7 @@ public record Prompt( // @formatter:off @JsonProperty("title") String title, @JsonProperty("description") String description, @JsonProperty("arguments") List arguments, - @JsonProperty("_meta") Map meta) implements BaseMetadata { // @formatter:on + @JsonProperty("_meta") Map meta) implements Identifier { // @formatter:on public Prompt(String name, String description, List arguments) { this(name, null, description, arguments != null ? arguments : new ArrayList<>()); @@ -1098,7 +1226,7 @@ public record PromptArgument( // @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, @JsonProperty("description") String description, - @JsonProperty("required") Boolean required) implements BaseMetadata { // @formatter:on + @JsonProperty("required") Boolean required) implements Identifier { // @formatter:on public PromptArgument(String name, String description, Boolean required) { this(name, null, description, required); @@ -1272,53 +1400,6 @@ public record Tool( // @formatter:off @JsonProperty("annotations") ToolAnnotations annotations, @JsonProperty("_meta") Map meta) { // @formatter:on - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, JsonSchema inputSchema, ToolAnnotations annotations) { - this(name, null, description, inputSchema, null, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String inputSchema) { - this(name, null, description, parseSchema(inputSchema), null, null, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String schema, ToolAnnotations annotations) { - this(name, null, description, parseSchema(schema), null, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String inputSchema, String outputSchema, - ToolAnnotations annotations) { - this(name, null, description, parseSchema(inputSchema), schemaToMap(outputSchema), annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String title, String description, String inputSchema, String outputSchema, - ToolAnnotations annotations) { - this(name, title, description, parseSchema(inputSchema), schemaToMap(outputSchema), annotations, null); - } - public static Builder builder() { return new Builder(); } @@ -1359,8 +1440,8 @@ public Builder inputSchema(JsonSchema inputSchema) { return this; } - public Builder inputSchema(String inputSchema) { - this.inputSchema = parseSchema(inputSchema); + public Builder inputSchema(McpJsonMapper jsonMapper, String inputSchema) { + this.inputSchema = parseSchema(jsonMapper, inputSchema); return this; } @@ -1369,8 +1450,8 @@ public Builder outputSchema(Map outputSchema) { return this; } - public Builder outputSchema(String outputSchema) { - this.outputSchema = schemaToMap(outputSchema); + public Builder outputSchema(McpJsonMapper jsonMapper, String outputSchema) { + this.outputSchema = schemaToMap(jsonMapper, outputSchema); return this; } @@ -1392,18 +1473,18 @@ public Tool build() { } } - private static Map schemaToMap(String schema) { + private static Map schemaToMap(McpJsonMapper jsonMapper, String schema) { try { - return OBJECT_MAPPER.readValue(schema, MAP_TYPE_REF); + return jsonMapper.readValue(schema, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid schema: " + schema, e); } } - private static JsonSchema parseSchema(String schema) { + private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { try { - return OBJECT_MAPPER.readValue(schema, JsonSchema.class); + return jsonMapper.readValue(schema, JsonSchema.class); } catch (IOException e) { throw new IllegalArgumentException("Invalid schema: " + schema, e); @@ -1427,17 +1508,17 @@ public record CallToolRequest( // @formatter:off @JsonProperty("arguments") Map arguments, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on - public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments), null); + public CallToolRequest(McpJsonMapper jsonMapper, String name, String jsonArguments) { + this(name, parseJsonArguments(jsonMapper, jsonArguments), null); } public CallToolRequest(String name, Map arguments) { this(name, arguments, null); } - private static Map parseJsonArguments(String jsonArguments) { + private static Map parseJsonArguments(McpJsonMapper jsonMapper, String jsonArguments) { try { - return OBJECT_MAPPER.readValue(jsonArguments, MAP_TYPE_REF); + return jsonMapper.readValue(jsonArguments, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); @@ -1466,8 +1547,8 @@ public Builder arguments(Map arguments) { return this; } - public Builder arguments(String jsonArguments) { - this.arguments = parseJsonArguments(jsonArguments); + public Builder arguments(McpJsonMapper jsonMapper, String jsonArguments) { + this.arguments = parseJsonArguments(jsonMapper, jsonArguments); return this; } @@ -1476,7 +1557,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -1508,15 +1589,21 @@ public CallToolRequest build() { public record CallToolResult( // @formatter:off @JsonProperty("content") List content, @JsonProperty("isError") Boolean isError, - @JsonProperty("structuredContent") Map structuredContent, + @JsonProperty("structuredContent") Object structuredContent, @JsonProperty("_meta") Map meta) implements Result { // @formatter:on - // backwards compatibility constructor + /** + * @deprecated use the builder instead. + */ + @Deprecated public CallToolResult(List content, Boolean isError) { - this(content, isError, null, null); + this(content, isError, (Object) null, null); } - // backwards compatibility constructor + /** + * @deprecated use the builder instead. + */ + @Deprecated public CallToolResult(List content, Boolean isError, Map structuredContent) { this(content, isError, structuredContent, null); } @@ -1530,6 +1617,7 @@ public CallToolResult(List content, Boolean isError, Map structuredContent; + private Object structuredContent; private Map meta; @@ -1566,16 +1654,16 @@ public Builder content(List content) { return this; } - public Builder structuredContent(Map structuredContent) { + public Builder structuredContent(Object structuredContent) { Assert.notNull(structuredContent, "structuredContent must not be null"); this.structuredContent = structuredContent; return this; } - public Builder structuredContent(String structuredContent) { + public Builder structuredContent(McpJsonMapper jsonMapper, String structuredContent) { Assert.hasText(structuredContent, "structuredContent must not be empty"); try { - this.structuredContent = OBJECT_MAPPER.readValue(structuredContent, MAP_TYPE_REF); + this.structuredContent = jsonMapper.readValue(structuredContent, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid structured content: " + structuredContent, e); @@ -1790,14 +1878,14 @@ public record CreateMessageRequest( // @formatter:off @JsonProperty("systemPrompt") String systemPrompt, @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, - @JsonProperty("maxTokens") int maxTokens, + @JsonProperty("maxTokens") Integer maxTokens, @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on // backwards compatibility constructor public CreateMessageRequest(List messages, ModelPreferences modelPreferences, - String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, int maxTokens, + String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, List stopSequences, Map metadata) { this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, metadata, null); @@ -1827,7 +1915,7 @@ public static class Builder { private Double temperature; - private int maxTokens; + private Integer maxTokens; private List stopSequences; @@ -1880,7 +1968,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -2048,7 +2136,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -2185,13 +2273,13 @@ public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ProgressNotification( // @formatter:off - @JsonProperty("progressToken") String progressToken, + @JsonProperty("progressToken") Object progressToken, @JsonProperty("progress") Double progress, @JsonProperty("total") Double total, @JsonProperty("message") String message, @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on - public ProgressNotification(String progressToken, double progress, Double total, String message) { + public ProgressNotification(Object progressToken, double progress, Double total, String message) { this(progressToken, progress, total, message, null); } } @@ -2203,6 +2291,7 @@ public ProgressNotification(String progressToken, double progress, Double total, * @param uri The updated resource uri. * @param meta See specification for notes on _meta usage */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ResourcesUpdatedNotification(// @formatter:off @JsonProperty("uri") String uri, @@ -2224,6 +2313,7 @@ public ResourcesUpdatedNotification(String uri) { * @param data JSON-serializable logging data. * @param meta See specification for notes on _meta usage */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record LoggingMessageNotification( // @formatter:off @JsonProperty("level") LoggingLevel level, @@ -2337,20 +2427,38 @@ public sealed interface CompleteReference permits PromptReference, ResourceRefer public record PromptReference( // @formatter:off @JsonProperty("type") String type, @JsonProperty("name") String name, - @JsonProperty("title") String title ) implements McpSchema.CompleteReference, BaseMetadata { // @formatter:on + @JsonProperty("title") String title ) implements McpSchema.CompleteReference, Identifier { // @formatter:on + + public static final String TYPE = "ref/prompt"; public PromptReference(String type, String name) { this(type, name, null); } public PromptReference(String name) { - this("ref/prompt", name, null); + this(TYPE, name, null); } @Override public String identifier() { return name(); } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + PromptReference that = (PromptReference) obj; + return java.util.Objects.equals(identifier(), that.identifier()) + && java.util.Objects.equals(type(), that.type()); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(identifier(), type()); + } } /** @@ -2365,8 +2473,10 @@ public record ResourceReference( // @formatter:off @JsonProperty("type") String type, @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { // @formatter:on + public static final String TYPE = "ref/resource"; + public ResourceReference(String uri) { - this("ref/resource", uri); + this(TYPE, uri); } @Override @@ -2429,8 +2539,9 @@ public record CompleteContext(@JsonProperty("arguments") Map arg */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteResult(@JsonProperty("completion") CompleteCompletion completion, - @JsonProperty("_meta") Map meta) implements Result { + public record CompleteResult(// @formatter:off + @JsonProperty("completion") CompleteCompletion completion, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on // backwards compatibility constructor public CompleteResult(CompleteCompletion completion) { @@ -2446,6 +2557,7 @@ public CompleteResult(CompleteCompletion completion) { * @param hasMore Indicates whether there are additional completion options beyond * those provided in the current response, even if the exact total is unknown */ + @JsonInclude(JsonInclude.Include.ALWAYS) public record CompleteCompletion( // @formatter:off @JsonProperty("values") List values, @JsonProperty("total") Integer total, @@ -2462,9 +2574,8 @@ public record CompleteCompletion( // @formatter:off @JsonSubTypes.Type(value = AudioContent.class, name = "audio"), @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource"), @JsonSubTypes.Type(value = ResourceLink.class, name = "resource_link") }) - public sealed interface Content permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { - - Map meta(); + public sealed interface Content extends Meta + permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { default String type() { if (this instanceof TextContent) { @@ -2689,29 +2800,7 @@ public record ResourceLink( // @formatter:off @JsonProperty("mimeType") String mimeType, @JsonProperty("size") Long size, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, Content, ResourceContent { // @formatter:on - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ResourceLink#ResourceLink(String, String, String, String, String, Long, Annotations)} - * instead. - */ - @Deprecated - public ResourceLink(String name, String title, String uri, String description, String mimeType, Long size, - Annotations annotations) { - this(name, title, uri, description, mimeType, size, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ResourceLink#ResourceLink(String, String, String, String, String, Long, Annotations)} - * instead. - */ - @Deprecated - public ResourceLink(String name, String uri, String description, String mimeType, Long size, - Annotations annotations) { - this(name, null, uri, description, mimeType, size, annotations); - } + @JsonProperty("_meta") Map meta) implements Content, ResourceContent { // @formatter:on public static Builder builder() { return new Builder(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 62985dc17..241f7d8b5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -11,12 +11,12 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpInitRequestHandler; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -153,7 +153,7 @@ public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.create(sink -> { @@ -198,26 +198,38 @@ public Mono sendNotification(String method, Object params) { * @return a Mono that completes when the message is processed */ public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); // TODO: Should the error go to SSE or back as POST return? return this.transport.sendMessage(errorResponse).then(Mono.empty()); }).flatMap(this.transport::sendMessage); @@ -227,7 +239,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, transportContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -240,15 +252,17 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param transportContext * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { + new TypeRef() { }); this.state.lazySet(STATE_INITIALIZING); @@ -266,30 +280,38 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), request.params())); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + return Mono.just( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, jsonRpcError)); + }); }); } /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param transportContext * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); // FIXME: The session ID passed here is not the same as the one in the // legacy SSE transport. exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), - clientInfo.get(), McpTransportContext.EMPTY)); + clientInfo.get(), transportContext)); } var handler = notificationHandlers.get(notification.method()); @@ -297,10 +319,23 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), notification.params())); }); } + /** + * This legacy implementation assumes an exchange is established upon the + * initialization phase see: exchangeSink.tryEmitValue(...), which creates a cached + * immutable exchange. Here, we create a new exchange and copy over everything from + * that cached exchange, and use it for a single HTTP request, with the transport + * context passed in. + */ + private McpAsyncServerExchange copyExchange(McpAsyncServerExchange exchange, McpTransportContext transportContext) { + return new McpAsyncServerExchange(exchange.sessionId(), this, exchange.getClientCapabilities(), + exchange.getClientInfo(), transportContext); + } + record MethodNotFoundError(String method, String message, Object data) { } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 3473a4da8..767ed673e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -37,7 +37,7 @@ public interface McpSession { * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received */ - Mono sendRequest(String method, Object requestParams, TypeReference typeRef); + Mono sendRequest(String method, Object requestParams, TypeRef typeRef); /** * Sends a notification to the model client or server without parameters. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java similarity index 90% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java index c1234b130..d1c2e5206 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -29,7 +29,7 @@ default void close() { Mono closeGracefully(); default List protocolVersions() { - return List.of(ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java similarity index 90% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ef7967c1e..95f8959f5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -15,12 +15,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -33,6 +34,7 @@ * capability without the insight into the transport-specific details of HTTP handling. * * @author Dariusz Jędrzejczyk + * @author Yanming Zhou */ public class McpStreamableServerSession implements McpLoggableSession { @@ -108,7 +110,7 @@ private String generateRequestId() { } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { return Mono.defer(() -> { McpLoggableSession listeningStream = this.listeningStreamRef.get(); return listeningStream.sendRequest(method, requestParams, typeRef); @@ -177,9 +179,13 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (e instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), McpError.aggregateExceptionMessages(e)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - e.getMessage(), null)); + null, jsonRpcError); return Mono.just(errorResponse); }) .flatMap(transport::sendMessage) @@ -214,19 +220,30 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { */ public Mono accept(McpSchema.JSONRPCResponse response) { return Mono.defer(() -> { - var stream = this.requestIdToStream.get(response.id()); - if (stream == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO - // JSONize - } - // TODO: encapsulate this inside the stream itself - var sink = stream.pendingResponses.remove(response.id()); - if (sink == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO - // JSONize + logger.debug("Received response: {}", response); + + if (response.id() != null) { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } return Mono.empty(); }); @@ -334,7 +351,7 @@ public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = McpStreamableServerSession.this.generateRequestId(); McpStreamableServerSession.this.requestIdToStream.put(requestId, this); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 1922548a6..0a732bab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -6,8 +6,8 @@ import java.util.List; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -77,7 +77,7 @@ default void close() { * @param typeRef the type reference for the object to unmarshal * @return the unmarshalled object */ - T unmarshalFrom(Object data, TypeReference typeRef); + T unmarshalFrom(Object data, TypeRef typeRef); default List protocolVersions() { return List.of(ProtocolVersions.MCP_2024_11_05); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index 716ff0d16..68f0fc5bb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -4,10 +4,10 @@ package io.modelcontextprotocol.spec; -import org.reactivestreams.Publisher; - import java.util.Optional; +import org.reactivestreams.Publisher; + /** * An abstraction of the session as perceived from the MCP transport layer. Not to be * confused with the {@link McpSession} type that operates at the level of the JSON-RPC diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java new file mode 100644 index 000000000..60e2850b9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java @@ -0,0 +1,23 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.util.annotation.Nullable; + +/** + * Exception thrown when trying to use an {@link McpTransportSession} that has been + * closed. + * + * @see ClosedMcpTransportSession + * @author Daniel Garnier-Moiroux + */ +public class McpTransportSessionClosedException extends RuntimeException { + + public McpTransportSessionClosedException(@Nullable String sessionId) { + super(sessionId != null ? "MCP session with ID %s has been closed".formatted(sessionId) + : "MCP session has been closed"); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java index aa33a8167..0bf70d5b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -31,7 +31,7 @@ public MissingMcpTransportSession(String sessionId) { } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java similarity index 77% rename from mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java index d8cb913a5..d3d34db62 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java @@ -20,4 +20,10 @@ public interface ProtocolVersions { */ String MCP_2025_06_18 = "2025-06-18"; + /** + * MCP protocol version for 2025-11-25. + * https://modelcontextprotocol.io/specification/2025-11-25 + */ + String MCP_2025_11_25 = "2025-11-25"; + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/Assert.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java index b2e9a5285..c3b922edf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -33,9 +33,7 @@ public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { * @param uriTemplate The URI template to be used for variable extraction */ public DefaultMcpUriTemplateManager(String uriTemplate) { - if (uriTemplate == null || uriTemplate.isEmpty()) { - throw new IllegalArgumentException("URI template must not be null or empty"); - } + Assert.hasText(uriTemplate, "URI template must not be null or empty"); this.uriTemplate = uriTemplate; } @@ -48,10 +46,6 @@ public DefaultMcpUriTemplateManager(String uriTemplate) { */ @Override public List getVariableNames() { - if (uriTemplate == null || uriTemplate.isEmpty()) { - return List.of(); - } - List variables = new ArrayList<>(); Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); @@ -81,7 +75,7 @@ public Map extractVariableValues(String requestUri) { Map variableValues = new HashMap<>(); List uriVariables = this.getVariableNames(); - if (requestUri == null || uriVariables.isEmpty()) { + if (!Utils.hasText(requestUri) || uriVariables.isEmpty()) { return variableValues; } @@ -147,12 +141,30 @@ public boolean matches(String uri) { return uri.equals(this.uriTemplate); } - // Convert the pattern to a regex - String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); - regex = regex.replace("/", "\\/"); + // Convert the URI template into a robust regex pattern that escapes special + // characters like '?'. + StringBuilder patternBuilder = new StringBuilder("^"); + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Append the literal part of the template, safely quoted + String textBefore = this.uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + // Append a capturing group for the variable itself + patternBuilder.append("([^/]+?)"); + lastEnd = variableMatcher.end(); + } + + // Append any remaining literal text after the last variable + if (lastEnd < this.uriTemplate.length()) { + patternBuilder.append(Pattern.quote(this.uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); // Check if the URI matches the regex - return Pattern.compile(regex).matcher(uri).matches(); + return Pattern.compile(patternBuilder.toString()).matcher(uri).matches(); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java similarity index 86% rename from mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java index 44ea31690..fd1a3bd71 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java @@ -7,7 +7,7 @@ /** * @author Christian Tzolov */ -public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { +public class DefaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { /** * Creates a new instance of {@link McpUriTemplateManager} with the specified URI diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java index 9d411cd41..6d53ed516 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -11,7 +11,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSession; @@ -33,7 +33,7 @@ public class KeepAliveScheduler { private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); - private static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + private static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; /** Initial delay before the first keepAlive call */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/Utils.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java index 039b0d68e..cd420100c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,12 +4,12 @@ package io.modelcontextprotocol.util; -import reactor.util.annotation.Nullable; - import java.net.URI; import java.util.Collection; import java.util.Map; +import reactor.util.annotation.Nullable; + /** * Miscellaneous utility methods. * diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java similarity index 86% rename from mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java index 6f041daa6..8f68f0d6e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -12,7 +12,7 @@ import java.util.List; import java.util.Map; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManager; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import org.junit.jupiter.api.BeforeEach; @@ -29,7 +29,7 @@ public class McpUriTemplateManagerTests { @BeforeEach void setUp() { - this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + this.uriTemplateFactory = new DefaultMcpUriTemplateManagerFactory(); } @Test @@ -94,4 +94,13 @@ void shouldMatchUriAgainstTemplatePattern() { assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); } + @Test + void shouldMatchUriWithQueryParameters() { + String templateWithQuery = "file://name/search?={search}"; + var uriTemplateManager = this.uriTemplateFactory.create(templateWithQuery); + + assertTrue(uriTemplateManager.matches("file://name/search?=abcd"), + "Should correctly match a URI containing query parameters."); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 92% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index b1113a6d0..9854de210 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -9,8 +9,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -99,8 +99,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java index 4be680e11..f3d6b77a7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java @@ -8,8 +8,8 @@ import java.util.List; import java.util.function.BiConsumer; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; @@ -53,14 +53,22 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; } + public void clearSentMessages() { + sent.clear(); + } + + public List getAllSentMessages() { + return new ArrayList<>(sent); + } + @Override public Mono closeGracefully() { return Mono.empty(); } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java similarity index 92% rename from mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index ec23e21dc..183b8a365 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -10,6 +10,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,7 +50,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) @@ -135,10 +136,13 @@ McpAsyncClient client(McpClientTransport transport, Function client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -218,9 +222,10 @@ void testSessionClose() { // In case of Streamable HTTP this call should issue a HTTP DELETE request // invalidating the session StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the - // request without any broken connections. - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java similarity index 98% rename from mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 3626d8ca0..57a223ea2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -22,8 +23,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -68,12 +67,6 @@ public abstract class AbstractMcpAsyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -118,16 +111,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, String action) { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -193,7 +176,12 @@ void testListAllToolsReturnsImmutableList() { .consumeNextWith(result -> { assertThat(result.tools()).isNotNull(); // Verify that the returned list is immutable - assertThatThrownBy(() -> result.tools().add(new Tool("test", "test", "{\"type\":\"object\"}"))) + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) .isInstanceOf(UnsupportedOperationException.class); }) .verifyComplete(); @@ -685,7 +673,7 @@ void testInitializeWithElicitationCapability() { @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) + .experimental(Map.of("feature", Map.of("featureFlag", true))) .roots(true) .sampling() .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java similarity index 98% rename from mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index c74255060..7ce12772c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -22,8 +22,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -67,12 +65,6 @@ public abstract class AbstractMcpSyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -115,17 +107,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { @@ -555,11 +536,13 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); withClient(createMcpTransport(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), client -> { assertThatCode(() -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java similarity index 77% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java index aef2ab8dd..c4157bc37 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -4,21 +4,22 @@ package io.modelcontextprotocol.client; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - @Timeout(15) public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { - private String host = "http://localhost:3001"; + private static String host = "http://localhost:3001"; // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -26,19 +27,18 @@ public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncCl @Override protected McpClientTransport createMcpTransport() { - return HttpClientStreamableHttpTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..d59ae35b4 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(URI.create(host + "/mcp")), + eq("{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}"), + eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java similarity index 91% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java index 0a72b785d..30e7fe913 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -28,7 +28,7 @@ import io.modelcontextprotocol.spec.McpSchema; import reactor.test.StepVerifier; -@Timeout(15) +@Timeout(20) public class HttpSseMcpAsyncClientLostConnectionTests { private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); @@ -38,7 +38,7 @@ public class HttpSseMcpAsyncClientLostConnectionTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) @@ -98,10 +98,13 @@ McpAsyncClient client(McpClientTransport transport) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(Duration.ofSeconds(14)) .initializationTimeout(Duration.ofSeconds(2)) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -119,7 +122,7 @@ void withClient(McpClientTransport transport, Consumer c) { } @Test - void testPingWithEaxctExceptionType() { + void testPingWithExactExceptionType() { withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 78% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 6cb3f7b65..f467289ff 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -19,11 +21,11 @@ @Timeout(15) class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - String host = "http://localhost:3004"; + private static String host = "http://localhost:3004"; // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -34,15 +36,15 @@ protected McpClientTransport createMcpTransport() { return HttpClientSseClientTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java new file mode 100644 index 000000000..483d38669 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") + .withCommand("node dist/index.js sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientSseClientTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(URI.create(host + "/sse")), + isNull(), eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java new file mode 100644 index 000000000..6f7390f19 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer} postInitializationHook functionality. + * + * @author Christian Tzolov + */ +class LifecycleInitializerPostInitializationHookTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void shouldInvokePostInitializationHook() { + AtomicReference capturedInit = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + capturedInit.set(invocation.getArgument(0)); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + + // Verify the hook received correct initialization data + assertThat(capturedInit.get()).isNotNull(); + assertThat(capturedInit.get().mcpSession()).isEqualTo(mockClientSession); + assertThat(capturedInit.get().initializeResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnce() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse initialization and NOT call hook again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Hook should only be called once + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnceWithConcurrentRequests() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // Start multiple concurrent initializations + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + // TODO: can we assume the order of results? + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Hook should only be called once despite concurrent requests + assertThat(hookInvocationCount.get()).isEqualTo(1); + } + + @Test + void shouldFailInitializationWhenPostInitializationHookFails() { + RuntimeException hookError = new RuntimeException("Post-initialization hook failed"); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.error(hookError)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectErrorMatches(ex -> ex instanceof RuntimeException && ex.getCause() == hookError) + .verify(); + + // Verify initialization was not completed + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // Verify the hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenInitializationFails() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.error(new RuntimeException("Initialization failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when initialization fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenNotificationFails() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when notification fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookAgainAfterReinitialization() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + assertThat(hookInvocationCount.get()).isEqualTo(1); + + // Simulate transport session exception to trigger re-initialization + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Hook should be called twice (once for each initialization) + assertThat(hookInvocationCount.get()).isEqualTo(2); + } + + @Test + void shouldAllowPostInitializationHookToPerformAsyncOperations() { + AtomicInteger operationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))) + .thenReturn(Mono.fromRunnable(() -> operationCount.incrementAndGet()).then()); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the async operation was executed + assertThat(operationCount.get()).isEqualTo(1); + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldProvideCorrectInitializationDataToHook() { + AtomicReference capturedSession = new AtomicReference<>(); + AtomicReference capturedResult = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + Initialization init = invocation.getArgument(0); + capturedSession.set(init.mcpSession()); + capturedResult.set(init.initializeResult()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook received the correct session and result + assertThat(capturedSession.get()).isEqualTo(mockClientSession); + assertThat(capturedResult.get()).isEqualTo(MOCK_INIT_RESULT); + assertThat(capturedResult.get().protocolVersion()).isEqualTo("2.0.0"); + assertThat(capturedResult.get().serverInfo().name()).isEqualTo("test-server"); + } + + @Test + void shouldInvokePostInitializationHookAfterSuccessfulInitialization() { + AtomicReference notificationSent = new AtomicReference<>(false); + AtomicReference hookCalledAfterNotification = new AtomicReference<>(false); + + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenAnswer(invocation -> { + notificationSent.set(true); + return Mono.empty(); + }); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + // Due to flatMap chaining in doInitialize, if the hook is called, + // the notification must have been sent first + hookCalledAfterNotification.set(notificationSent.get()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook was called and notification was already sent at that point + assertThat(hookCalledAfterNotification.get()).isTrue(); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + verify(mockPostInitializationHook).apply(any(Initialization.class)); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java index 02021edbf..787ee9480 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -10,14 +10,14 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; - -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; @@ -58,12 +58,16 @@ class LifecycleInitializerTests { @Mock private Function mockSessionSupplier; + @Mock + private Function> mockPostInitializationHook; + private LifecycleInitializer initializer; @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT)); @@ -72,45 +76,45 @@ void setUp() { when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, mockSessionSupplier); + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); } @Test void constructorShouldValidateParameters() { assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT, - mockSessionSupplier)) + mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Client capabilities must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Client info must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Protocol versions must not be empty"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(), - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Protocol versions must not be empty"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null, - mockSessionSupplier)) + mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Initialization timeout must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, null)) + INITIALIZATION_TIMEOUT, null, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Session supplier must not be null"); } @Test void shouldInitializeSuccessfully() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { assertThat(result).isEqualTo(MOCK_INIT_RESULT); assertThat(initializer.isInitialized()).isTrue(); @@ -132,7 +136,7 @@ void shouldUseLatestProtocolVersionInInitializeRequest() { return Mono.just(MOCK_INIT_RESULT); }); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest // version @@ -152,7 +156,7 @@ void shouldFailForUnsupportedProtocolVersion() { when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(unsupportedResult)); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); @@ -167,13 +171,13 @@ void shouldTimeoutOnSlowInitialization() { Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5); LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, - PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier); + PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); StepVerifier - .withVirtualTime(() -> shortTimeoutInitializer.withIntitialization("test", + .withVirtualTime(() -> shortTimeoutInitializer.withInitialization("test", init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) .expectSubscription() .expectNoEvent(INITIALIZE_TIMEOUT) @@ -184,12 +188,12 @@ void shouldTimeoutOnSlowInitialization() { @Test void shouldReuseExistingInitialization() { // First initialization - StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) .expectNext("result1") .verifyComplete(); // Second call should reuse the same initialization - StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) .expectNext("result2") .verifyComplete(); @@ -209,11 +213,11 @@ void shouldHandleConcurrentInitializationRequests() { // Start multiple concurrent initializations using subscribeOn with parallel // scheduler - Mono init1 = initializer.withIntitialization("test1", init -> Mono.just("result1")) + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) .subscribeOn(Schedulers.parallel()); - Mono init2 = initializer.withIntitialization("test2", init -> Mono.just("result2")) + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) .subscribeOn(Schedulers.parallel()); - Mono init3 = initializer.withIntitialization("test3", init -> Mono.just("result3")) + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) .subscribeOn(Schedulers.parallel()); StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { @@ -230,20 +234,32 @@ void shouldHandleConcurrentInitializationRequests() { @Test void shouldHandleInitializationFailure() { when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.error(new RuntimeException("Connection failed"))); + // fail once + .thenReturn(Mono.error(new RuntimeException("Connection failed"))) + // succeeds on the second call + .thenReturn(Mono.just(MOCK_INIT_RESULT)); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); assertThat(initializer.isInitialized()).isFalse(); assertThat(initializer.currentInitializationResult()).isNull(); + + // The initializer can recover from previous errors + StepVerifier + .create(initializer.withInitialization("successful init", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); } @Test void shouldHandleTransportSessionNotFoundException() { // successful initialization first - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -265,7 +281,7 @@ void shouldHandleTransportSessionNotFoundException() { @Test void shouldHandleOtherExceptions() { // Simulate a successful initialization first - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -283,7 +299,7 @@ void shouldHandleOtherExceptions() { @Test void shouldCloseGracefully() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -295,7 +311,7 @@ void shouldCloseGracefully() { @Test void shouldCloseImmediately() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -330,7 +346,7 @@ void shouldSetProtocolVersionsForTesting() { new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); }); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { // Latest from new versions assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0"); @@ -351,7 +367,7 @@ void shouldPassContextToSessionSupplier() { }); StepVerifier - .create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())) + .create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult())) .contextWrite(Context.of(contextKey, contextValue))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -362,7 +378,7 @@ void shouldPassContextToSessionSupplier() { @Test void shouldProvideAccessToMcpSessionAndInitializeResult() { - StepVerifier.create(initializer.withIntitialization("test", init -> { + StepVerifier.create(initializer.withInitialization("test", init -> { assertThat(init.mcpSession()).isEqualTo(mockClientSession); assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT); return Mono.just("success"); @@ -374,7 +390,7 @@ void shouldHandleNotificationFailure() { when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) .thenReturn(Mono.error(new RuntimeException("Notification failed"))); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); @@ -391,7 +407,7 @@ void shouldReturnNullWhenNotInitialized() { @Test void shouldReinitializeAfterTransportSessionException() { // First initialization - StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) .expectNext("result1") .verifyComplete(); @@ -399,7 +415,7 @@ void shouldReinitializeAfterTransportSessionException() { initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); // Should be able to initialize again - StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) .expectNext("result2") .verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index cab847512..612a65898 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -4,14 +4,13 @@ package io.modelcontextprotocol.client; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -24,6 +23,7 @@ import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -93,7 +93,7 @@ void testSuccessfulInitialization() { } @Test - void testToolsChangeNotificationHandling() throws JsonProcessingException { + void testToolsChangeNotificationHandling() throws IOException { MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification @@ -110,8 +110,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { // Create a mock tools list that the server will return Map inputSchema = Map.of("type", "object", "properties", Map.of(), "required", List.of()); - McpSchema.Tool mockTool = new McpSchema.Tool("test-tool-1", "Test Tool 1 Description", - new ObjectMapper().writeValueAsString(inputSchema)); + McpSchema.Tool mockTool = McpSchema.Tool.builder() + .name("test-tool-1") + .description("Test Tool 1 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 1 response with nextPageToken String nextPageToken = "page2Token"; @@ -131,9 +134,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { transport.simulateIncomingMessage(toolsListResponse1); // Create mock tools for page 2 - McpSchema.Tool mockTool2 = new McpSchema.Tool("test-tool-2", "Test Tool 2 Description", - new ObjectMapper().writeValueAsString(inputSchema)); - + McpSchema.Tool mockTool2 = McpSchema.Tool.builder() + .name("test-tool-2") + .description("Test Tool 2 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 2 response with no nextPageToken (last page) McpSchema.ListToolsResult mockToolsResult2 = new McpSchema.ListToolsResult(List.of(mockTool2), null); @@ -321,7 +326,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), - new TypeReference() { + new TypeRef() { }); assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); @@ -425,7 +430,7 @@ void testElicitationCreateRequestHandling() { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); @@ -470,7 +475,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(action); @@ -551,4 +556,4 @@ void testPingMessageRequestHandling() { asyncMcpClient.closeGracefully(); } -} \ No newline at end of file +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java new file mode 100644 index 000000000..48bf1da5b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -0,0 +1,310 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +class McpAsyncClientTests { + + public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", + "1.0.0"); + + public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .build(); + + public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( + ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + + private static final String CONTEXT_KEY = "context.key"; + + private McpClientTransport createMockTransportForToolValidation(boolean hasOutputSchema, boolean invalidOutput) { + + // Create tool with or without output schema + Map inputSchemaMap = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + + McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", inputSchemaMap, null, null, null, null); + McpSchema.Tool.Builder toolBuilder = McpSchema.Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(inputSchema); + + if (hasOutputSchema) { + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + toolBuilder.outputSchema(outputSchema); + } + + McpSchema.Tool calculatorTool = toolBuilder.build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result - valid or invalid based on parameter + Map structuredContent = invalidOutput ? Map.of("result", "5", "operation", "add") + : Map.of("result", 5, "operation", "add"); + + McpSchema.CallToolResult mockCallToolResult = McpSchema.CallToolResult.builder() + .addTextContent("Calculation result") + .structuredContent(structuredContent) + .build(); + + return new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + this.handler = handler; + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else if (McpSchema.METHOD_TOOLS_CALL.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + mockCallToolResult, null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + } + + @Test + void validateContextPassedToTransportConnect() { + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + final AtomicReference contextValue = new AtomicReference<>(); + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + if (ctx.hasKey(CONTEXT_KEY)) { + this.contextValue.set(ctx.get(CONTEXT_KEY)); + } + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!"hello".equals(this.contextValue.get())) { + return Mono.error(new RuntimeException("Context value not propagated via #connect method")); + } + // We're only interested in handling the init request to provide an init + // response + if (!(message instanceof McpSchema.JSONRPCRequest)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); + return handler.apply(Mono.just(initResponse)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + assertThatCode(() -> { + McpAsyncClient client = McpClient.async(transport).build(); + client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); + }).doesNotThrowAnyException(); + } + + @Test + void testCallToolWithOutputSchemaValidationSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(true, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithNoOutputSchemaSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(false, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithOutputSchemaValidationFailure() { + McpClientTransport transport = createMockTransportForToolValidation(true, true); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectErrorMatches(ex -> ex instanceof IllegalArgumentException + && ex.getMessage().contains("Tool call result validation failed")) + .verify(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testListToolsWithEmptyCursor() { + McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); + McpSchema.Tool subtractTool = McpSchema.Tool.builder() + .name("subtract") + .description("calculate subtract") + .build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool, subtractTool), ""); + + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + Mono mono = client.listTools(); + McpSchema.ListToolsResult toolsResult = mono.block(); + assertThat(toolsResult).isNotNull(); + + Set names = toolsResult.tools().stream().map(McpSchema.Tool::name).collect(Collectors.toSet()); + assertThat(names).containsExactlyInAnyOrder("subtract", "add"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java similarity index 92% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 3feb1d05c..a94b9b6a7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -8,9 +8,9 @@ import java.util.List; import io.modelcontextprotocol.MockMcpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -22,7 +22,7 @@ */ class McpClientProtocolVersionTests { - private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30); + private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(300); private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); @@ -46,13 +46,12 @@ void shouldUseLatestVersionByDefault() { assertThat(initRequest.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(protocolVersion, null, + new McpSchema.InitializeResult(protocolVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { assertThat(result.protocolVersion()).isEqualTo(protocolVersion); }).verifyComplete(); - } finally { // Ensure cleanup happens even if test fails @@ -81,7 +80,7 @@ void shouldNegotiateSpecificVersion() { assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(oldVersion, null, + new McpSchema.InitializeResult(oldVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { @@ -110,7 +109,7 @@ void shouldFailForUnsupportedVersion() { assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(unsupportedVersion, null, + new McpSchema.InitializeResult(unsupportedVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).expectError(RuntimeException.class).verify(); @@ -143,7 +142,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { assertThat(initRequest.protocolVersion()).isEqualTo(latestVersion); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(latestVersion, null, + new McpSchema.InitializeResult(latestVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java new file mode 100644 index 000000000..63ec015fe --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.ServerParameters; + +public final class ServerParameterUtils { + + private ServerParameterUtils() { + } + + public static ServerParameters createServerParameters() { + if (System.getProperty("os.name").toLowerCase().contains("win")) { + return ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") + .build(); + } + return ServerParameters.builder("npx").args("-y", "@modelcontextprotocol/server-everything", "stdio").build(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java similarity index 50% rename from mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index e9356d0c0..aa8aaa397 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -11,33 +11,37 @@ import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); + return new StdioClientTransport(createServerParameters(), JSON_MAPPER); } protected Duration getInitializationTimeout() { return Duration.ofSeconds(20); } + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 4b5f4f9c0..b1e567989 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -17,31 +17,30 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; /** * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); + ServerParameters stdioParams = createServerParameters(); + return new StdioClientTransport(stdioParams, JSON_MAPPER); } @Test @@ -71,4 +70,9 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(10); } + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java similarity index 89% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 46b9207f6..c5c365798 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -14,9 +14,12 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -33,6 +36,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.util.UriComponentsBuilder; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; @@ -54,7 +58,7 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -62,6 +66,8 @@ class HttpClientSseClientTransportTests { private TestHttpClientSseClientTransport transport; + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); + // Test class to access protected methods static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { @@ -71,8 +77,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo public TestHttpClientSseClientTransport(final String baseUri) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), - HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", - new ObjectMapper(), AsyncHttpRequestCustomizer.NOOP); + HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, + McpAsyncHttpClientRequestCustomizer.NOOP); } public int getInboundMessageCount() { @@ -389,7 +395,7 @@ void testChainedCustomizations() { @Test void testRequestCustomizer() { - var mockCustomizer = mock(SyncHttpRequestCustomizer.class); + var mockCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); // Create a transport with the customizer var customizedTransport = HttpClientSseClientTransport.builder(host) @@ -397,11 +403,14 @@ void testRequestCustomizer() { .build(); // Connect - StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockCustomizer).customize(any(), eq("GET"), - eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); clearInvocations(mockCustomizer); // Send test message @@ -409,12 +418,16 @@ void testRequestCustomizer() { Map.of("key", "value")); // Subscribe to messages and verify - StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); // Clean up @@ -423,8 +436,8 @@ void testRequestCustomizer() { @Test void testAsyncRequestCustomizer() { - var mockCustomizer = mock(AsyncHttpRequestCustomizer.class); - when(mockCustomizer.customize(any(), any(), any(), any())) + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); // Create a transport with the customizer @@ -433,11 +446,14 @@ void testAsyncRequestCustomizer() { .build(); // Connect - StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockCustomizer).customize(any(), eq("GET"), - eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); clearInvocations(mockCustomizer); // Send test message @@ -445,12 +461,16 @@ void testAsyncRequestCustomizer() { Map.of("key", "value")); // Subscribe to messages and verify - StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); // Clean up diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java index 8b3668671..81e642681 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java @@ -22,6 +22,7 @@ import com.sun.net.httpserver.HttpServer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -70,14 +71,14 @@ static void stopContainer() { void testNotificationInitialized() throws URISyntaxException { var uri = new URI(host + "/mcp"); - var mockRequestCustomizer = mock(SyncHttpRequestCustomizer.class); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); var transport = HttpClientStreamableHttpTransport.builder(host) .httpRequestCustomizer(mockRequestCustomizer) .build(); var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + new McpSchema.Implementation("MCP Client", "0.3.1")); var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", initializeRequest); @@ -85,7 +86,8 @@ void testNotificationInitialized() throws URISyntaxException { // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + any()); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java similarity index 99% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index 2b502a83b..b82d6eb2c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -63,7 +63,7 @@ void startServer() throws IOException { if ("DELETE".equals(httpExchange.getRequestMethod())) { httpExchange.sendResponseHeaders(200, 0); } - else { + else if ("POST".equals(httpExchange.getRequestMethod())) { // Capture session ID from request if present String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); lastReceivedSessionId.set(requestSessionId); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java similarity index 57% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java index d645bb0b3..f9536b690 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -4,9 +4,13 @@ package io.modelcontextprotocol.client.transport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import java.net.URI; import java.net.URISyntaxException; +import java.util.Map; import java.util.function.Consumer; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -32,8 +36,11 @@ class HttpClientStreamableHttpTransportTest { static String host = "http://localhost:3001"; + private McpTransportContext context = McpTransportContext + .create(Map.of("test-transport-context-key", "some-value")); + @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -63,7 +70,7 @@ void withTransport(HttpClientStreamableHttpTransport transport, Consumer ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); }); } @Test void testAsyncRequestCustomizer() throws URISyntaxException { var uri = new URI(host + "/mcp"); - var mockRequestCustomizer = mock(AsyncHttpRequestCustomizer.class); - when(mockRequestCustomizer.customize(any(), any(), any(), any())) + var mockRequestCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any(), any())) .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); var transport = HttpClientStreamableHttpTransport.builder(host) @@ -100,16 +110,54 @@ void testAsyncRequestCustomizer() throws URISyntaxException { // Send test message var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + new McpSchema.Implementation("MCP Client", "0.3.1")); var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", initializeRequest); - StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); }); } + @Test + void testCloseUninitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..a04787aa3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingMcpAsyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpAsyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + StepVerifier + .create(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context)) + .expectNext(TEST_BUILDER) + .verifyComplete(); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "one")), + (builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "two")))); + + var headers = Mono + .from(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", + McpTransportContext.EMPTY)) + .map(HttpRequest.Builder::build) + .map(HttpRequest::headers) + .flatMapIterable(h -> h.allValues("x-test")); + + StepVerifier.create(headers).expectNext("one").expectNext("two").verifyComplete(); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..6c51a3d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DelegatingMcpSyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpSyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = Mockito.mock(McpSyncHttpClientRequestCustomizer.class); + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var testHeaderName = "x-test"; + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> builder.header(testHeaderName, "one"), + (builder, method, uri, body, ctx) -> builder.header(testHeaderName, "two"))); + + customizer.customize(TEST_BUILDER, "GET", TEST_URI, null, McpTransportContext.EMPTY); + var request = TEST_BUILDER.build(); + + assertThat(request.headers().allValues(testHeaderName)).containsExactly("one", "two"); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..8b2dea462 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication, demonstrating how contextual information can be passed from client to + * server through HTTP headers and accessed within server-side handlers. + * + *

Test Scenarios

+ *

+ * The tests cover multiple transport configurations with async servers: + *

    + *
  • Stateless server with async streamable HTTP clients
  • + *
  • Streamable server with async streamable HTTP clients
  • + *
  • SSE (Server-Sent Events) server with async SSE clients
  • + *
+ * + *

Context Propagation Flow

+ *
    + *
  1. Client-side: Context data is stored in the Reactor Context and injected into HTTP + * headers via {@link McpSyncHttpClientRequestCustomizer}
  2. + *
  3. Transport: The context travels as HTTP headers (specifically "x-test" header in + * these tests)
  4. + *
  5. Server-side: A {@link McpTransportContextExtractor} extracts the header value and + * makes it available to request handlers through {@link McpTransportContext}
  6. + *
  7. Verification: The server echoes back the received context value as the tool call + * result
  8. + *
+ * + *

+ * All tests use an embedded Tomcat server running on a dynamically allocated port to + * ensure isolation and prevent port conflicts during parallel test execution. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final String HEADER_NAME = "x-test"; + + private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + return Mono.just(builder); + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpAsyncClient asyncStreamableClient = McpClient + .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopTomcat(); + } + + @Test + void asyncClinetStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..8efb6a960 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.McpTestRequestRecordingServletFilter; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class HttpClientStreamableHttpVersionNegotiationIntegrationTests { + + private Tomcat tomcat; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private final McpTestRequestRecordingServletFilter requestRecordingFilter = new McpTestRequestRecordingServletFilter(); + + private final HttpServletStreamableServerTransportProvider transport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor( + req -> McpTransportContext.create(Map.of("protocol-version", req.getHeader("MCP-protocol-version")))) + .build(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + McpSyncServer mcpServer = McpServer.sync(transport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @AfterEach + void tearDown() { + stopTomcat(); + } + + @Test + void usesLatestVersion() { + startTomcat(); + + var client = McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT).build()) + .build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + startTomcat(); + + var transport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls).filteredOn(c -> c.method().equals("POST") && !c.body().contains("\"method\":\"initialize\"")) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + private void startTomcat() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport, requestRecordingFilter); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc8f4c4be --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,243 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test both Client and Server {@link McpTransportContext} integration, in two steps. + *

+ * First, the client calls a tool and writes data stored in a thread-local to an HTTP + * header using {@link SyncSpec#transportContextProvider(Supplier)} and + * {@link McpSyncHttpClientRequestCustomizer}. + *

+ * Then the server reads the header with a {@link McpTransportContextExtractor} and + * returns the value as the result of the tool call. + * + * @author Daniel Garnier-Moiroux + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java similarity index 56% rename from mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 0ba8bf929..090710248 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -4,18 +4,9 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.time.Duration; import java.util.List; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -26,9 +17,17 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** * Test suite for the {@link McpAsyncServer} that can be used with different * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. @@ -86,30 +85,27 @@ void testGracefulShutdown() { void testImmediateClose() { var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test @Deprecated void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -117,14 +113,20 @@ void testAddTool() { @Test void testAddToolCall() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -133,68 +135,88 @@ void testAddToolCall() { @Test @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build())).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ); @@ -207,17 +229,23 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -226,11 +254,17 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -244,20 +278,23 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -296,8 +333,13 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); @@ -314,7 +356,7 @@ void testAddResourceWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); }); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -325,14 +367,19 @@ void testAddResourceWithoutCapability() { // Create a server without resource capabilities McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } @@ -342,11 +389,191 @@ void testRemoveResourceWithoutCapability() { McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + // --------------------------------------- // Prompts Tests // --------------------------------------- @@ -368,7 +595,8 @@ void testAddPromptWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); }); } @@ -383,7 +611,7 @@ void testAddPromptWithoutCapability() { .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -394,7 +622,7 @@ void testRemovePromptWithoutCapability() { McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -424,10 +652,7 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java similarity index 78% rename from mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index e2adb340c..1f5387f37 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -4,14 +4,6 @@ package io.modelcontextprotocol.server; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -23,15 +15,14 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -50,11 +41,25 @@ import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -66,7 +71,7 @@ public abstract class AbstractMcpClientServerIntegrationTests { abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -74,7 +79,6 @@ void simple(String clientType) { var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .requestTimeout(Duration.ofSeconds(1000)) .build(); - try ( // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -84,23 +88,25 @@ void simple(String clientType) { assertThat(client.initialize()).isNotNull(); } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } // --------------------------------------- // Sampling Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); }) .build(); @@ -121,11 +127,13 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { .hasMessage("Client must be configured with sampling capabilities"); } } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -138,13 +146,14 @@ void testCreateMessageSuccess(String clientType) { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -188,11 +197,13 @@ void testCreateMessageSuccess(String clientType) { assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); }); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { // Client @@ -212,20 +223,16 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -249,30 +256,35 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .requestTimeout(Duration.ofSeconds(4)) .tools(tool) .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - - mcpClient.close(); - mcpServer.close(); + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { var clientBuilder = clientBuilders.get(clientType); @@ -290,16 +302,12 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -322,28 +330,34 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.close(); - mcpServer.close(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Elicitation Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) .then(Mono.just(mock(CallToolResult.class)))) .build(); @@ -363,11 +377,13 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { .hasMessage("Client must be configured with elicitation capabilities"); } } - server.closeGracefully().block(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -380,11 +396,12 @@ void testCreateElicitationSuccess(String clientType) { Map.of("message", request.message())); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -418,11 +435,13 @@ void testCreateElicitationSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); } - mcpServer.closeGracefully().block(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -433,18 +452,14 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -464,25 +479,31 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - assertWith(resultRef.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutFail(String clientType) { var latch = new CountDownLatch(1); @@ -504,17 +525,12 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = ElicitRequest.builder() @@ -534,25 +550,31 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - ElicitResult elicitResult = resultRef.get(); - assertThat(elicitResult).isNull(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Roots Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -594,18 +616,19 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { exchange.listRoots(); // try to list roots @@ -632,12 +655,13 @@ void testRootsWithoutCapability(String clientType) { assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); } } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -661,12 +685,13 @@ void testRootsNotificationWithEmptyRootsList(String clientType) { assertThat(rootsRef.get()).isEmpty(); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -694,12 +719,13 @@ void testRootsWithMultipleHandlers(String clientType) { assertThat(rootsRef2.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsServerCloseWithActiveSubscription(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -725,30 +751,26 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { assertThat(rootsRef.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { try { @@ -759,7 +781,7 @@ void testToolCallSuccess(String clientType) { .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); - assertThat(responseBody).isNotBlank(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); } catch (Exception e) { e.printStackTrace(); @@ -782,14 +804,16 @@ void testToolCallSuccess(String clientType) { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); assertThat(response).isNotNull().isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -800,7 +824,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .tool(Tool.builder() .name("tool1") .description("tool1 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { // We trigger a timeout on blocking read, raising an exception @@ -815,25 +839,87 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { assertThat(initResult).isNotNull(); // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. + // the offending tool instead of getting back a timeout. assertThatExceptionOfType(McpError.class) .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) .withMessageContaining("Timeout on blocking read"); } + finally { + mcpServer.closeGracefully(); + } + } - mcpServer.close(); + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTransportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { // perform a blocking call to a remote service try { @@ -902,7 +988,7 @@ void testToolListChangeHandlingSuccess(String clientType) { .tool(Tool.builder() .name("tool2") .description("tool2 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> callResponse) .build(); @@ -913,12 +999,13 @@ void testToolListChangeHandlingSuccess(String clientType) { assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -930,15 +1017,16 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Logging Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testLoggingNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 3; CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); @@ -946,13 +1034,13 @@ void testLoggingNotification(String clientType) throws InterruptedException { List receivedNotifications = new CopyOnWriteArrayList<>(); var clientBuilder = clientBuilders.get(clientType); - ; + // Create server with a tool that sends logging notifications McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() .tool(Tool.builder() .name("logging-test") .description("Test logging notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -989,13 +1077,16 @@ void testLoggingNotification(String clientType) throws InterruptedException { .logger("test-logger") .data("Another error message") .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); //@formatter:on }) .build(); var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool) .build(); @@ -1042,14 +1133,16 @@ void testLoggingNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Progress Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress // token @@ -1064,7 +1157,7 @@ void testProgressNotification(String clientType) throws InterruptedException { .tool(McpSchema.Tool.builder() .name("progress-test") .description("Test progress notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1082,7 +1175,10 @@ void testProgressNotification(String clientType) throws InterruptedException { 0.0, 1.0, "Another processing started"))) .then(exchange.progressNotification( new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); }) .build(); @@ -1147,7 +1243,7 @@ void testProgressNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); } finally { - mcpServer.close(); + mcpServer.closeGracefully().block(); } } @@ -1155,7 +1251,7 @@ void testProgressNotification(String clientType) throws InterruptedException { // Completion Tests // --------------------------------------- @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCompletionShouldReturnExpectedSuggestions(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1177,7 +1273,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -1186,7 +1283,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), new CompleteRequest.CompleteArgument("language", "py")); CompleteResult result = mcpClient.completeCompletion(request); @@ -1195,17 +1292,18 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); } - - mcpServer.close(); } // --------------------------------------- // Ping Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testPingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1217,7 +1315,7 @@ void testPingSuccess(String clientType) { .tool(Tool.builder() .name("ping-async-test") .description("Test ping async behavior") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1234,7 +1332,10 @@ void testPingSuccess(String clientType) { assertThat(result).isNotNull(); }).then(Mono.fromCallable(() -> { executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); })); }) .build(); @@ -1259,15 +1360,16 @@ void testPingSuccess(String clientType) { // Verify execution order assertThat(executionOrder.get()).isEqualTo("123"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1319,7 +1421,7 @@ void testStructuredOutputValidationSuccess(String clientType) { // In WebMVC, structured content is returned properly if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) .containsEntry("operation", "2 + 3") .containsEntry("timestamp", "2024-01-01T10:00:00Z"); } @@ -1335,8 +1437,125 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - mcpServer.close(); + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -1389,12 +1608,13 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1439,12 +1659,13 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1516,8 +1737,9 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } private double evaluateExpression(String expression) { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java similarity index 50% rename from mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 67579ce72..915c658e3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -4,17 +4,8 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.util.List; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -25,6 +16,14 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpSyncServer} that can be used with different @@ -78,14 +77,14 @@ void testConstructorWithInvalidArguments() { void testGracefulShutdown() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testImmediateClose() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); } @Test @@ -94,21 +93,13 @@ void testGetAsyncServer() { assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test @Deprecated void testAddTool() { @@ -116,12 +107,16 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -130,75 +125,98 @@ void testAddToolCall() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build())).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); List specs = List.of( McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ); @@ -211,17 +229,22 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -230,16 +253,20 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -248,19 +275,18 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -271,9 +297,9 @@ void testNotifyToolsListChanged() { void testNotifyResourcesListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -284,7 +310,7 @@ void testNotifyResourcesUpdated() { .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -293,14 +319,18 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -310,31 +340,208 @@ void testAddResourceWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Resource must not be null"); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -345,9 +552,9 @@ void testRemoveResourceWithoutCapability() { void testNotifyPromptsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -357,7 +564,7 @@ void testAddPromptWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Prompt specification must not be null"); } @@ -370,7 +577,8 @@ void testAddPromptWithoutCapability() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -378,7 +586,8 @@ void testAddPromptWithoutCapability() { void testRemovePromptWithoutCapability() { var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -396,7 +605,7 @@ void testRemovePrompt() { assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -405,10 +614,9 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -429,9 +637,8 @@ void testRootsChangeHandlers() { } })) .build(); - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test with multiple consumers @@ -447,7 +654,7 @@ void testRootsChangeHandlers() { .build(); assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test error handling @@ -458,14 +665,14 @@ void testRootsChangeHandlers() { .build(); assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test without consumers var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java index 6744826c9..62332fcdb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java @@ -4,12 +4,14 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.List; import java.util.Map; +import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -26,21 +28,19 @@ */ class AsyncToolSpecificationBuilderTest { - String emptyJsonSchema = """ - { - "type": "object" - } - """; - @Test void builderShouldCreateValidAsyncToolSpecification() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() .tool(tool) .callHandler((exchange, request) -> Mono - .just(new CallToolResult(List.of(new TextContent("Test result")), false))) + .just(CallToolResult.builder().content(List.of(new TextContent("Test result"))).isError(false).build())) .build(); assertThat(specification).isNotNull(); @@ -52,13 +52,18 @@ void builderShouldCreateValidAsyncToolSpecification() { @Test void builderShouldThrowExceptionWhenToolIsNull() { assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder() - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); } @Test void builderShouldThrowExceptionWhenCallToolIsNull() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder().tool(tool).build()) .isInstanceOf(IllegalArgumentException.class) @@ -67,24 +72,36 @@ void builderShouldThrowExceptionWhenCallToolIsNull() { @Test void builderShouldAllowMethodChaining() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); McpServerFeatures.AsyncToolSpecification.Builder builder = McpServerFeatures.AsyncToolSpecification.builder(); // Then - verify method chaining returns the same builder instance assertThat(builder.tool(tool)).isSameAs(builder); - assertThat(builder.callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))) + assertThat(builder.callHandler( + (exchange, request) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build()))) .isSameAs(builder); } @Test void builtSpecificationShouldExecuteCallToolCorrectly() { - Tool tool = new Tool("calculator", "Simple calculator", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("calculator") + .title("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); String expectedResult = "42"; McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() .tool(tool) .callHandler((exchange, request) -> { - return Mono.just(new CallToolResult(List.of(new TextContent(expectedResult)), false)); + return Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); }) .build(); @@ -103,13 +120,20 @@ void builtSpecificationShouldExecuteCallToolCorrectly() { @Test @SuppressWarnings("deprecation") void deprecatedConstructorShouldWorkCorrectly() { - Tool tool = new Tool("deprecated-tool", "A deprecated tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("deprecated-tool") + .title("A deprecated tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); String expectedResult = "deprecated result"; // Test the deprecated constructor that takes a 'call' function McpServerFeatures.AsyncToolSpecification specification = new McpServerFeatures.AsyncToolSpecification(tool, - (exchange, arguments) -> Mono - .just(new CallToolResult(List.of(new TextContent(expectedResult)), false))); + (exchange, + arguments) -> Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build())); assertThat(specification).isNotNull(); assertThat(specification.tool()).isEqualTo(tool); @@ -143,13 +167,20 @@ void deprecatedConstructorShouldWorkCorrectly() { @Test void fromSyncShouldConvertSyncToolSpecificationCorrectly() { - Tool tool = new Tool("sync-tool", "A sync tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("sync-tool") + .title("A sync tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); String expectedResult = "sync result"; // Create a sync tool specification McpServerFeatures.SyncToolSpecification syncSpec = McpServerFeatures.SyncToolSpecification.builder() .tool(tool) - .callHandler((exchange, request) -> new CallToolResult(List.of(new TextContent(expectedResult)), false)) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()) .build(); // Convert to async using fromSync @@ -178,14 +209,21 @@ void fromSyncShouldConvertSyncToolSpecificationCorrectly() { @Test @SuppressWarnings("deprecation") void fromSyncShouldConvertSyncToolSpecificationWithDeprecatedCallCorrectly() { - Tool tool = new Tool("sync-deprecated-tool", "A sync tool with deprecated call", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name("sync-deprecated-tool") + .title("A sync tool with deprecated call") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); String expectedResult = "sync deprecated result"; McpAsyncServerExchange nullExchange = null; // Mock or create a suitable exchange // if needed // Create a sync tool specification using the deprecated constructor McpServerFeatures.SyncToolSpecification syncSpec = new McpServerFeatures.SyncToolSpecification(tool, - (exchange, arguments) -> new CallToolResult(List.of(new TextContent(expectedResult)), false)); + (exchange, arguments) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); // Convert to async using fromSync McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index 823c28d8e..d2b9d14d0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -4,25 +4,27 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; -import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import static org.assertj.core.api.Assertions.assertThat; @Timeout(15) class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -37,11 +39,15 @@ class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationT private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); @@ -92,4 +98,7 @@ public void after() { protected void prepareClients(int port, String mcpEndpoint) { } + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java similarity index 76% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index a8951e6dc..491c2d4ed 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -4,22 +4,30 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.PromptArgument; import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.ProtocolVersions; import net.javacrumbs.jsonunit.core.Option; @@ -32,19 +40,15 @@ import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.web.client.RestClient; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; - import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -66,7 +70,6 @@ class HttpServletStatelessIntegrationTests { @BeforeEach public void before() { this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -105,24 +108,19 @@ public void after() { // --------------------------------------- // Tools Tests // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( - new Tool("tool1", "tool1 description", emptyJsonSchema), (transportContext, request) -> { + Tool.builder().name("tool1").title("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build(), + (transportContext, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -150,8 +148,9 @@ void testToolCallSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -166,8 +165,9 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } // --------------------------------------- @@ -197,7 +197,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (transportContext, getPromptRequest) -> null)) .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -206,7 +206,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), new CompleteRequest.CompleteArgument("language", "py")); CompleteResult result = mcpClient.completeCompletion(request); @@ -215,10 +215,11 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.close(); } - - mcpServer.close(); } // --------------------------------------- @@ -289,8 +290,129 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.close(); + } + } - mcpServer.close(); + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -341,8 +463,9 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -390,8 +513,9 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -464,12 +588,13 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @Test - void testThrownMcpError() throws Exception { + void testThrownMcpErrorAndJsonRpcError() throws Exception { var mcpServer = McpServer.sync(mcpStatelessServerTransport) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) @@ -479,7 +604,7 @@ void testThrownMcpError() throws Exception { McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( testTool, (transportContext, request) -> { - throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + throw new RuntimeException("testing"); }); mcpServer.addTool(toolSpec); @@ -491,7 +616,7 @@ void testThrownMcpError() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); MockHttpServletResponse response = new MockHttpServletResponse(); - byte[] content = new ObjectMapper().writeValueAsBytes(jsonrpcRequest); + byte[] content = JSON_MAPPER.writeValueAsBytes(jsonrpcRequest); request.setContent(content); request.addHeader("Content-Type", "application/json"); request.addHeader("Content-Length", Integer.toString(content.length)); @@ -500,13 +625,16 @@ void testThrownMcpError() throws Exception { request.addHeader("Content-Type", APPLICATION_JSON); request.addHeader("Cache-Control", "no-cache"); request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + mcpStatelessServerTransport.service(request, response); - McpSchema.JSONRPCResponse jsonrpcResponse = new ObjectMapper().readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse jsonrpcResponse = JSON_MAPPER.readValue(response.getContentAsByteArray(), McpSchema.JSONRPCResponse.class); - assertThat(jsonrpcResponse.error()) - .isEqualTo(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + assertThat(jsonrpcResponse).isNotNull(); + assertThat(jsonrpcResponse.error()).isNotNull(); + assertThat(jsonrpcResponse.error().code()).isEqualTo(ErrorCodes.INTERNAL_ERROR); + assertThat(jsonrpcResponse.error().message()).isEqualTo("testing"); mcpServer.close(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java index 327ec1b21..96f1524b7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java @@ -6,8 +6,6 @@ import org.junit.jupiter.api.Timeout; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -21,10 +19,7 @@ class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests { protected McpStreamableServerTransportProvider createMcpTransportProvider() { - return HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint("/mcp/message") - .build(); + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 8a8675d95..81423e0c5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -4,25 +4,27 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; -import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import static org.assertj.core.api.Assertions.assertThat; @Timeout(15) class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -35,11 +37,15 @@ class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerInteg private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .mcpEndpoint(MESSAGE_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(1)) .build(); @@ -90,4 +96,7 @@ public void after() { protected void prepareClients(int port, String mcpEndpoint) { } + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java index 66fa2b2ac..87c0712dc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java @@ -6,8 +6,6 @@ import org.junit.jupiter.api.Timeout; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -21,10 +19,7 @@ class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests { protected McpStreamableServerTransportProvider createMcpTransportProvider() { - return HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint("/mcp/message") - .build(); + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 987c43663..640d34c9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -4,15 +4,16 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -54,7 +55,7 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, - new DefaultMcpTransportContext()); + McpTransportContext.EMPTY); } @Test @@ -65,7 +66,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(singlePageResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -93,11 +94,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -119,7 +120,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(emptyResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -139,7 +140,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(result)); StepVerifier.create(exchange.listRoots("someCursor")).assertNext(listResult -> { @@ -153,7 +154,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -174,11 +175,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -313,8 +314,7 @@ void testCreateElicitationWithNullCapabilities() { }); // Verify that sendRequest was never called due to null capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -338,8 +338,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { // Verify that sendRequest was never called due to missing elicitation // capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -373,8 +372,7 @@ void testCreateElicitationWithComplexRequest() { .content(responseContent) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -404,8 +402,7 @@ void testCreateElicitationWithDeclineAction() { .message(McpSchema.ElicitResult.Action.DECLINE) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -432,8 +429,7 @@ void testCreateElicitationWithCancelAction() { .message(McpSchema.ElicitResult.Action.CANCEL) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -456,8 +452,7 @@ void testCreateElicitationWithSessionError() { .message("Please provide your name") .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { @@ -487,7 +482,7 @@ void testCreateMessageWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -512,7 +507,7 @@ void testCreateMessageWithoutSamplingCapabilities() { // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -538,7 +533,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -576,7 +571,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -602,7 +597,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { @@ -634,7 +629,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -652,7 +647,7 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(expectedResponse)); StepVerifier.create(exchange.ping()).assertNext(result -> { @@ -661,14 +656,14 @@ void testPingWithSuccessfulResponse() { }).verifyComplete(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.error(mcpError)); // When & Then @@ -676,13 +671,13 @@ void testPingWithMcpError() { assertThat(error).isInstanceOf(McpError.class).hasMessage("Server unavailable"); }); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -697,7 +692,7 @@ void testPingMultipleCalls() { }).verifyComplete(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java similarity index 93% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index f915895be..54fb80a78 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -12,14 +12,13 @@ import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; + import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; @@ -59,7 +58,6 @@ class McpCompletionTests { public void before() { // Create and con figure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -99,7 +97,7 @@ void testCompletionHandlerReceivesContext() { return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false)); }; - ResourceReference resourceRef = new ResourceReference("ref/resource", "test://resource/{param}"); + ResourceReference resourceRef = new ResourceReference(ResourceReference.TYPE, "test://resource/{param}"); var resource = Resource.builder() .uri("test://resource/{param}") @@ -154,7 +152,7 @@ void testCompletionBackwardCompatibility() { .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "test-prompt"), completionHandler)) + new PromptReference(PromptReference.TYPE, "test-prompt"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -163,7 +161,7 @@ void testCompletionBackwardCompatibility() { assertThat(initResult).isNotNull(); // Test without context - CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "test-prompt"), + CompleteRequest request = new CompleteRequest(new PromptReference(PromptReference.TYPE, "test-prompt"), new CompleteRequest.CompleteArgument("arg", "val")); CompleteResult result = mcpClient.completeCompletion(request); @@ -219,7 +217,7 @@ else if ("products_db".equals(db)) { .resources(new McpServerFeatures.SyncResourceSpecification(resource, (exchange, req) -> new ReadResourceResult(List.of()))) .completions(new McpServerFeatures.SyncCompletionSpecification( - new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -229,7 +227,7 @@ else if ("products_db".equals(db)) { // First, complete database CompleteRequest dbRequest = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("database", "")); CompleteResult dbResult = mcpClient.completeCompletion(dbRequest); @@ -237,7 +235,7 @@ else if ("products_db".equals(db)) { // Then complete table with database context CompleteRequest tableRequest = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "users_db"))); @@ -246,7 +244,7 @@ else if ("products_db".equals(db)) { // Different database gives different tables CompleteRequest tableRequest2 = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "products_db"))); @@ -296,7 +294,7 @@ void testCompletionErrorOnMissingContext() { .resources(new McpServerFeatures.SyncResourceSpecification(resource, (exchange, req) -> new ReadResourceResult(List.of()))) .completions(new McpServerFeatures.SyncCompletionSpecification( - new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0")) @@ -306,7 +304,7 @@ void testCompletionErrorOnMissingContext() { // Try to complete table without database context - should raise error CompleteRequest requestWithoutContext = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", "")); assertThatExceptionOfType(McpError.class) @@ -315,7 +313,7 @@ void testCompletionErrorOnMissingContext() { // Now complete with proper context - should work normally CompleteRequest requestWithContext = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "test_db"))); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index a73ec7209..069d0f896 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -9,10 +9,10 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -66,7 +66,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(singlePageResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -94,11 +94,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -120,7 +120,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(emptyResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -140,7 +140,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(result)); McpSchema.ListRootsResult listResult = exchange.listRoots("someCursor"); @@ -154,7 +154,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -173,11 +173,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -308,8 +308,7 @@ void testCreateElicitationWithNullCapabilities() { .hasMessage("Client must be initialized. Call the initialize method first!"); // Verify that sendRequest was never called due to null capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -333,8 +332,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { // Verify that sendRequest was never called due to missing elicitation // capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -369,8 +367,7 @@ void testCreateElicitationWithComplexRequest() { .content(responseContent) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -401,8 +398,7 @@ void testCreateElicitationWithDeclineAction() { .message(McpSchema.ElicitResult.Action.DECLINE) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -430,8 +426,7 @@ void testCreateElicitationWithCancelAction() { .message(McpSchema.ElicitResult.Action.CANCEL) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -455,8 +450,7 @@ void testCreateElicitationWithSessionError() { .message("Please provide your name") .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithElicitation.createElicitation(elicitRequest)) @@ -487,7 +481,7 @@ void testCreateMessageWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -512,7 +506,7 @@ void testCreateMessageWithoutSamplingCapabilities() { // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -539,7 +533,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -578,7 +572,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -605,7 +599,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithSampling.createMessage(createMessageRequest)) @@ -638,7 +632,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -656,32 +650,32 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(expectedResponse)); exchange.ping(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.error(mcpError)); // When & Then assertThatThrownBy(() -> exchange.ping()).isInstanceOf(McpError.class).hasMessage("Server unavailable"); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -692,7 +686,7 @@ void testPingMultipleCalls() { exchange.ping(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java new file mode 100644 index 000000000..61703c306 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test to verify the separation of regular resources and resource templates. Regular + * resources (without template parameters) should only appear in resources/list. Template + * resources (containing {}) should only appear in resources/templates/list. + */ +public class ResourceTemplateListingTest { + + @Test + void testTemplateResourcesFilteredFromRegularListing() { + // The change we made filters resources containing "{" from the regular listing + // This test verifies that behavior is working correctly + + // Given a string with template parameter + String templateUri = "file:///test/{userId}/profile.txt"; + assertThat(templateUri.contains("{")).isTrue(); + + // And a regular URI + String regularUri = "file:///test/regular.txt"; + assertThat(regularUri.contains("{")).isFalse(); + + // The filter should exclude template URIs + assertThat(!templateUri.contains("{")).isFalse(); + assertThat(!regularUri.contains("{")).isTrue(); + } + + @Test + void testResourceListingWithMixedResources() { + // Create resource list with both regular and template resources + List allResources = List.of( + new McpSchema.Resource("file:///test/doc1.txt", "Document 1", "text/plain", null, null), + new McpSchema.Resource("file:///test/doc2.txt", "Document 2", "text/plain", null, null), + new McpSchema.Resource("file:///test/{type}/document.txt", "Typed Document", "text/plain", null, null), + new McpSchema.Resource("file:///users/{userId}/files/{fileId}", "User File", "text/plain", null, null)); + + // Apply the filter logic from McpAsyncServer line 438 + List filteredResources = allResources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Verify only regular resources are included + assertThat(filteredResources).hasSize(2); + assertThat(filteredResources).extracting(McpSchema.Resource::uri) + .containsExactlyInAnyOrder("file:///test/doc1.txt", "file:///test/doc2.txt"); + } + + @Test + void testResourceTemplatesListedSeparately() { + // Create mixed resources + List resources = List.of( + new McpSchema.Resource("file:///test/regular.txt", "Regular Resource", "text/plain", null, null), + new McpSchema.Resource("file:///test/user/{userId}/profile.txt", "User Profile", "text/plain", null, + null)); + + // Create explicit resource template + McpSchema.ResourceTemplate explicitTemplate = new McpSchema.ResourceTemplate( + "file:///test/document/{docId}/content.txt", "Document Template", null, "text/plain", null); + + // Filter regular resources (those without template parameters) + List regularResources = resources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Extract template resources (those with template parameters) + List templateResources = resources.stream() + .filter(resource -> resource.uri().contains("{")) + .map(resource -> new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations())) + .collect(Collectors.toList()); + + // Verify regular resources list + assertThat(regularResources).hasSize(1); + assertThat(regularResources.get(0).uri()).isEqualTo("file:///test/regular.txt"); + + // Verify template resources list includes both extracted and explicit templates + assertThat(templateResources).hasSize(1); + assertThat(templateResources.get(0).uriTemplate()).isEqualTo("file:///test/user/{userId}/profile.txt"); + + // In the actual implementation, both would be combined + List allTemplates = List.of(templateResources.get(0), explicitTemplate); + assertThat(allTemplates).hasSize(2); + assertThat(allTemplates).extracting(McpSchema.ResourceTemplate::uriTemplate) + .containsExactlyInAnyOrder("file:///test/user/{userId}/profile.txt", + "file:///test/document/{docId}/content.txt"); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java new file mode 100644 index 000000000..b7d46a967 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Test suite for Resource Template Management functionality. Tests the new + * addResourceTemplate() and removeResourceTemplate() methods, as well as the Map-based + * resource template storage. + * + * @author Christian Tzolov + */ +public class ResourceTemplateManagementTests { + + private static final String TEST_TEMPLATE_URI = "test://resource/{param}"; + + private static final String TEST_TEMPLATE_NAME = "test-template"; + + private MockMcpServerTransportProvider mockTransportProvider; + + private McpAsyncServer mcpAsyncServer; + + @BeforeEach + void setUp() { + mockTransportProvider = new MockMcpServerTransportProvider(new MockMcpServerTransport()); + } + + @AfterEach + void tearDown() { + if (mcpAsyncServer != null) { + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + } + + // --------------------------------------- + // Async Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).verifyComplete(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate(TEST_TEMPLATE_URI)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource template should complete successfully (no + // error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + } + + @Test + void testReplaceExistingResourceTemplate() { + ResourceTemplate originalTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Original template") + .mimeType("text/plain") + .build(); + + ResourceTemplate updatedTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Updated template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification originalSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + originalTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification updatedSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + updatedTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(originalSpec) + .build(); + + // Adding a resource template with the same URI should replace the existing one + StepVerifier.create(mcpAsyncServer.addResourceTemplate(updatedSpec)).verifyComplete(); + } + + // --------------------------------------- + // Sync Resource Template Tests + // --------------------------------------- + + @Test + void testSyncAddResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testSyncRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Map-based Storage Tests + // --------------------------------------- + + @Test + void testResourceTemplateMapBasedStorage() { + ResourceTemplate template1 = ResourceTemplate.builder() + .uriTemplate("test://template1/{id}") + .name("template1") + .description("First template") + .mimeType("text/plain") + .build(); + + ResourceTemplate template2 = ResourceTemplate.builder() + .uriTemplate("test://template2/{id}") + .name("template2") + .description("Second template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification spec1 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template1, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification spec2 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template2, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(spec1, spec2) + .build(); + + // Verify both templates are stored (this would be tested through integration + // tests + // or by accessing internal state, but for unit tests we verify no exceptions) + assertThat(mcpAsyncServer).isNotNull(); + } + + @Test + void testResourceTemplateBuilderWithMap() { + // Test that the new Map-based builder methods work correctly + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + // Test varargs builder method + assertThatCode(() -> { + McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build() + .closeGracefully() + .block(Duration.ofSeconds(10)); + }).doesNotThrowAnyException(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 97db5fa06..b2dfbea25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. * @@ -17,7 +19,7 @@ class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index 1e01962e9..c97c75d38 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * @@ -17,7 +19,7 @@ class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java similarity index 71% rename from mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java index 4aac46952..9bcd2bc84 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -24,20 +25,17 @@ */ class SyncToolSpecificationBuilderTest { - String emptyJsonSchema = """ - { - "type": "object" - } - """; - @Test void builderShouldCreateValidSyncToolSpecification() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = Tool.builder().name("test-tool").title("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() .tool(tool) - .callHandler((exchange, request) -> new CallToolResult(List.of(new TextContent("Test result")), false)) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent("Test result"))) + .isError(false) + .build()) .build(); assertThat(specification).isNotNull(); @@ -49,13 +47,13 @@ void builderShouldCreateValidSyncToolSpecification() { @Test void builderShouldThrowExceptionWhenToolIsNull() { assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder() - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); } @Test void builderShouldThrowExceptionWhenCallToolIsNull() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder().tool(tool).build()) .isInstanceOf(IllegalArgumentException.class) @@ -64,24 +62,33 @@ void builderShouldThrowExceptionWhenCallToolIsNull() { @Test void builderShouldAllowMethodChaining() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); McpServerFeatures.SyncToolSpecification.Builder builder = McpServerFeatures.SyncToolSpecification.builder(); // Then - verify method chaining returns the same builder instance assertThat(builder.tool(tool)).isSameAs(builder); - assertThat(builder.callHandler((exchange, request) -> new CallToolResult(List.of(), false))).isSameAs(builder); + assertThat(builder + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build())) + .isSameAs(builder); } @Test void builtSpecificationShouldExecuteCallToolCorrectly() { - Tool tool = new Tool("calculator", "Simple calculator", emptyJsonSchema); + Tool tool = Tool.builder() + .name("calculator") + .description("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); String expectedResult = "42"; McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() .tool(tool) .callHandler((exchange, request) -> { // Simple test implementation - return new CallToolResult(List.of(new TextContent(expectedResult)), false); + return CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build(); }) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 0462cbafe..be88097b3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -40,7 +38,6 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .baseUrl(CUSTOM_CONTEXT_PATH) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java new file mode 100644 index 000000000..b94552d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; + +/** + * Simple {@link Filter} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingServletFilter implements Filter { + + private final List calls = new ArrayList<>(); + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + + if (servletRequest instanceof HttpServletRequest req) { + var headers = Collections.list(req.getHeaderNames()) + .stream() + .collect(Collectors.toUnmodifiableMap(Function.identity(), + name -> String.join(",", Collections.list(req.getHeaders(name))))); + var request = new CachedBodyHttpServletRequest(req); + calls.add(new Call(req.getMethod(), headers, request.getBodyAsString())); + filterChain.doFilter(request, servletResponse); + } + else { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + public List getCalls() { + + return List.copyOf(calls); + } + + public record Call(String method, Map headers, String body) { + + } + + public static class CachedBodyHttpServletRequest extends HttpServletRequestWrapper { + + private final byte[] cachedBody; + + public CachedBodyHttpServletRequest(HttpServletRequest request) throws IOException { + super(request); + this.cachedBody = request.getInputStream().readAllBytes(); + } + + @Override + public ServletInputStream getInputStream() { + return new CachedBodyServletInputStream(cachedBody); + } + + @Override + public BufferedReader getReader() { + return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8)); + } + + public String getBodyAsString() { + return new String(cachedBody, StandardCharsets.UTF_8); + } + + } + + public static class CachedBodyServletInputStream extends ServletInputStream { + + private InputStream cachedBodyInputStream; + + public CachedBodyServletInputStream(byte[] cachedBody) { + this.cachedBodyInputStream = new ByteArrayInputStream(cachedBody); + } + + @Override + public boolean isFinished() { + try { + return cachedBodyInputStream.available() == 0; + } + catch (IOException e) { + e.printStackTrace(); + } + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + throw new UnsupportedOperationException(); + } + + @Override + public int read() throws IOException { + return cachedBodyInputStream.read(); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java similarity index 93% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5ac..6a70af33d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -14,7 +14,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -26,6 +25,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -50,8 +50,6 @@ class StdioServerTransportProviderTests { private StdioServerTransportProvider transportProvider; - private ObjectMapper objectMapper; - private McpServerSession.Factory sessionFactory; private McpServerSession mockSession; @@ -64,8 +62,6 @@ void setUp() { System.setOut(testOutPrintStream); System.setErr(testOutPrintStream); - objectMapper = new ObjectMapper(); - // Create mocks for session factory and session mockSession = mock(McpServerSession.class); sessionFactory = mock(McpServerSession.Factory.class); @@ -75,7 +71,7 @@ void setUp() { when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); - transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, System.in, testOutPrintStream); } @AfterEach @@ -105,7 +101,7 @@ void shouldHandleIncomingMessages() throws Exception { String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, System.out); // Set up a real session to capture the message AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); @@ -185,7 +181,7 @@ void shouldHandleMultipleCloseGracefullyCalls() { @Test void shouldHandleNotificationBeforeSessionFactoryIsSet() { - transportProvider = new StdioServerTransportProvider(objectMapper); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER); // Send notification before setting session factory StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) .verifyErrorSatisfies(error -> { @@ -200,7 +196,7 @@ void shouldHandleInvalidJsonMessage() throws Exception { String jsonMessage = "{invalid json}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, testOutPrintStream); // Set up a session factory transportProvider.setSessionFactory(sessionFactory); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java similarity index 76% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index 2cf95dc94..490e29838 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -8,6 +8,7 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; +import jakarta.servlet.Filter; import jakarta.servlet.Servlet; import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; @@ -24,7 +25,8 @@ public class TomcatTestUtil { // Prevent instantiation } - public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet, + Filter... additionalFilters) { var tomcat = new Tomcat(); tomcat.setPort(port); @@ -43,15 +45,17 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se context.addChild(wrapper); context.addServletMappingDecoded("/*", "mcpServlet"); - var filterDef = new FilterDef(); - filterDef.setFilterClass(McpTestServletFilter.class.getName()); - filterDef.setFilterName(McpTestServletFilter.class.getSimpleName()); - context.addFilterDef(filterDef); + for (var filter : additionalFilters) { + var filterDef = new FilterDef(); + filterDef.setFilter(filter); + filterDef.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + context.addFilterDef(filterDef); - var filterMap = new FilterMap(); - filterMap.setFilterName(McpTestServletFilter.class.getSimpleName()); - filterMap.addURLPattern("/*"); - context.addFilterMap(filterMap); + var filterMap = new FilterMap(); + filterMap.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + filterMap.addURLPattern("/*"); + context.addFilterMap(filterMap); + } var connector = tomcat.getConnector(); connector.setAsyncTimeout(3000); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java new file mode 100644 index 000000000..55f71fea4 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.McpJsonMapper; +import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.util.Collections; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CompleteCompletionSerializationTest { + + @Test + void codeCompletionSerialization() throws IOException { + McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + McpSchema.CompleteResult.CompleteCompletion codeComplete = new McpSchema.CompleteResult.CompleteCompletion( + Collections.emptyList(), 0, false); + String json = jsonMapper.writeValueAsString(codeComplete); + String expected = """ + {"values":[],"total":0,"hasMore":false}"""; + assertEquals(expected, json, json); + + McpSchema.CompleteResult completeResult = new McpSchema.CompleteResult(codeComplete); + json = jsonMapper.writeValueAsString(completeResult); + expected = """ + {"completion":{"values":[],"total":0,"hasMore":false}}"""; + assertEquals(expected, json, json); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java similarity index 91% rename from mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java index d03a6926d..fbe17d464 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java @@ -5,7 +5,10 @@ package io.modelcontextprotocol.spec; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Tests for MCP-specific validation of JSONRPCRequest ID requirements. diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java new file mode 100644 index 000000000..3de06f503 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Function; + +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, + * request-response correlation, and notification processing. + * + * @author Christian Tzolov + */ +class McpClientSessionTests { + + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + private static final String TEST_METHOD = "test.method"; + + private static final String TEST_NOTIFICATION = "test.notification"; + + private static final String ECHO_METHOD = "echo"; + + TypeRef responseType = new TypeRef<>() { + }; + + @Test + void testSendRequest() { + String testParam = "test parameter"; + String responseData = "test response"; + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Create a Mono that will emit the response after the request is sent + Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); + // Verify response handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); + }).consumeNextWith(response -> { + // Verify the request was sent + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; + assertThat(request.method()).isEqualTo(TEST_METHOD); + assertThat(request.params()).isEqualTo(testParam); + assertThat(response).isEqualTo(responseData); + }).verifyComplete(); + + session.close(); + } + + @Test + void testSendRequestWithError() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify error handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + // Simulate error response + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }).expectError(McpError.class).verify(); + + session.close(); + } + + @Test + void testRequestTimeout() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify timeout + StepVerifier.create(responseMono) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(TIMEOUT.plusSeconds(1)); + + session.close(); + } + + @Test + void testSendNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Map params = Map.of("key", "value"); + Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); + + // Verify notification was sent + StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; + assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); + assertThat(notification.params()).isEqualTo(params); + }).verifyComplete(); + + session.close(); + } + + @Test + void testRequestHandling() { + String echoMessage = "Hello MCP!"; + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "test-id", echoMessage); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.result()).isEqualTo(echoMessage); + assertThat(response.error()).isNull(); + + session.close(); + } + + @Test + void testNotificationHandling() { + Sinks.One receivedParams = Sinks.one(); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))), + Function.identity()); + + // Simulate incoming notification from the server + Map notificationParams = Map.of("status", "ready"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + TEST_NOTIFICATION, notificationParams); + + transport.simulateIncomingMessage(notification); + + // Verify handler was called + assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); + + session.close(); + } + + @Test + void testUnknownMethodHandling() { + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Simulate incoming request for unknown method + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + + session.close(); + } + + @Test + void testRequestHandlerThrowsMcpErrorWithJsonRpcError() { + // Setup: Create a request handler that throws McpError with custom error code and + // data + String testMethod = "test.customError"; + Map errorData = Map.of("customField", "customValue"); + McpClientSession.RequestHandler failingHandler = params -> Mono + .error(McpError.builder(123).message("Custom error message").data(errorData).build()); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain the custom error from McpError + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(123); + assertThat(response.error().message()).isEqualTo("Custom error message"); + assertThat(response.error().data()).isEqualTo(errorData); + + session.close(); + } + + @Test + void testRequestHandlerThrowsGenericException() { + // Setup: Create a request handler that throws a generic RuntimeException + String testMethod = "test.genericError"; + RuntimeException exception = new RuntimeException("Something went wrong"); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(exception); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with aggregated exception + // messages in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Something went wrong"); + // Verify data field contains aggregated exception messages + assertThat(response.error().data()).isNotNull(); + assertThat(response.error().data().toString()).contains("RuntimeException"); + assertThat(response.error().data().toString()).contains("Something went wrong"); + + session.close(); + } + + @Test + void testRequestHandlerThrowsExceptionWithCause() { + // Setup: Create a request handler that throws an exception with a cause chain + String testMethod = "test.chainedError"; + RuntimeException rootCause = new IllegalArgumentException("Root cause message"); + RuntimeException middleCause = new IllegalStateException("Middle cause message", rootCause); + RuntimeException topException = new RuntimeException("Top level message", middleCause); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(topException); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with full exception chain + // in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Top level message"); + // Verify data field contains the full exception chain + String dataString = response.error().data().toString(); + assertThat(dataString).contains("RuntimeException"); + assertThat(dataString).contains("Top level message"); + assertThat(dataString).contains("IllegalStateException"); + assertThat(dataString).contains("Middle cause message"); + assertThat(dataString).contains("IllegalArgumentException"); + assertThat(dataString).contains("Root cause message"); + + session.close(); + } + + @Test + void testGracefulShutdown() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + StepVerifier.create(session.closeGracefully()).verifyComplete(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java new file mode 100644 index 000000000..0978ffe0b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class McpErrorTest { + + @Test + void testNotFound() { + String uri = "file:///nonexistent.txt"; + McpError mcpError = McpError.RESOURCE_NOT_FOUND.apply(uri); + assertNotNull(mcpError.getJsonRpcError()); + assertEquals(-32002, mcpError.getJsonRpcError().code()); + assertEquals("Resource not found", mcpError.getJsonRpcError().message()); + assertEquals(Map.of("uri", uri), mcpError.getJsonRpcError().data()); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java similarity index 82% rename from mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index a5b2137fd..6b0004cb9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.spec; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -17,7 +18,6 @@ import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; @@ -29,14 +29,12 @@ */ public class McpSchemaTests { - ObjectMapper mapper = new ObjectMapper(); - // Content Types Tests @Test void testTextContent() throws Exception { McpSchema.TextContent test = new McpSchema.TextContent("XXX"); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -47,7 +45,7 @@ void testTextContent() throws Exception { @Test void testTextContentDeserialization() throws Exception { - McpSchema.TextContent textContent = mapper.readValue(""" + McpSchema.TextContent textContent = JSON_MAPPER.readValue(""" {"type":"text","text":"XXX","_meta":{"metaKey":"metaValue"}}""", McpSchema.TextContent.class); assertThat(textContent).isNotNull(); @@ -59,7 +57,7 @@ void testTextContentDeserialization() throws Exception { @Test void testContentDeserializationWrongType() throws Exception { - assertThatThrownBy(() -> mapper.readValue(""" + assertThatThrownBy(() -> JSON_MAPPER.readValue(""" {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) .isInstanceOf(InvalidTypeIdException.class) .hasMessageContaining( @@ -69,7 +67,7 @@ void testContentDeserializationWrongType() throws Exception { @Test void testImageContent() throws Exception { McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -80,7 +78,7 @@ void testImageContent() throws Exception { @Test void testImageContentDeserialization() throws Exception { - McpSchema.ImageContent imageContent = mapper.readValue(""" + McpSchema.ImageContent imageContent = JSON_MAPPER.readValue(""" {"type":"image","data":"base64encodeddata","mimeType":"image/png","_meta":{"metaKey":"metaValue"}}""", McpSchema.ImageContent.class); assertThat(imageContent).isNotNull(); @@ -93,7 +91,7 @@ void testImageContentDeserialization() throws Exception { @Test void testAudioContent() throws Exception { McpSchema.AudioContent audioContent = new McpSchema.AudioContent(null, "base64encodeddata", "audio/wav"); - String value = mapper.writeValueAsString(audioContent); + String value = JSON_MAPPER.writeValueAsString(audioContent); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -104,7 +102,7 @@ void testAudioContent() throws Exception { @Test void testAudioContentDeserialization() throws Exception { - McpSchema.AudioContent audioContent = mapper.readValue(""" + McpSchema.AudioContent audioContent = JSON_MAPPER.readValue(""" {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav","_meta":{"metaKey":"metaValue"}}""", McpSchema.AudioContent.class); assertThat(audioContent).isNotNull(); @@ -140,7 +138,7 @@ void testCreateMessageRequestWithMeta() throws Exception { .meta(meta) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -158,7 +156,7 @@ void testEmbeddedResource() throws Exception { McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -169,7 +167,7 @@ void testEmbeddedResource() throws Exception { @Test void testEmbeddedResourceDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( """ {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"},"_meta":{"metaKey":"metaValue"}}""", McpSchema.EmbeddedResource.class); @@ -189,7 +187,7 @@ void testEmbeddedResourceWithBlobContents() throws Exception { McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -200,7 +198,7 @@ void testEmbeddedResourceWithBlobContents() throws Exception { @Test void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( """ {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob","_meta":{"metaKey":"metaValue"}}}""", McpSchema.EmbeddedResource.class); @@ -219,7 +217,7 @@ void testResourceLink() throws Exception { McpSchema.ResourceLink resourceLink = new McpSchema.ResourceLink("main.rs", "Main file", "file:///project/src/main.rs", "Primary application entry point", "text/x-rust", null, null, Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(resourceLink); + String value = JSON_MAPPER.writeValueAsString(resourceLink); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -231,7 +229,7 @@ void testResourceLink() throws Exception { @Test void testResourceLinkDeserialization() throws Exception { - McpSchema.ResourceLink resourceLink = mapper.readValue( + McpSchema.ResourceLink resourceLink = JSON_MAPPER.readValue( """ {"type":"resource_link","name":"main.rs","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""", McpSchema.ResourceLink.class); @@ -254,7 +252,7 @@ void testJSONRPCRequest() throws Exception { McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, params); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -270,7 +268,7 @@ void testJSONRPCNotification() throws Exception { McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "notification_method", params); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -285,7 +283,7 @@ void testJSONRPCResponse() throws Exception { McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); - String value = mapper.writeValueAsString(response); + String value = JSON_MAPPER.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -300,7 +298,7 @@ void testJSONRPCResponseWithError() throws Exception { McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); - String value = mapper.writeValueAsString(response); + String value = JSON_MAPPER.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -323,7 +321,7 @@ void testInitializeRequest() throws Exception { McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2024_11_05, capabilities, clientInfo, meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -346,7 +344,7 @@ void testInitializeResult() throws Exception { McpSchema.InitializeResult result = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, capabilities, serverInfo, "Server initialized successfully"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -365,7 +363,7 @@ void testResource() throws Exception { McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", "text/plain", annotations); - String value = mapper.writeValueAsString(resource); + String value = JSON_MAPPER.writeValueAsString(resource); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -389,7 +387,7 @@ void testResourceBuilder() throws Exception { .meta(Map.of("metaKey", "metaValue")) .build(); - String value = mapper.writeValueAsString(resource); + String value = JSON_MAPPER.writeValueAsString(resource); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -436,7 +434,7 @@ void testResourceTemplate() throws Exception { McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", "Test Template", "A test resource template", "text/plain", annotations, meta); - String value = mapper.writeValueAsString(template); + String value = JSON_MAPPER.writeValueAsString(template); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -458,7 +456,7 @@ void testListResourcesResult() throws Exception { McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), "next-cursor", meta); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -478,7 +476,7 @@ void testListResourceTemplatesResult() throws Exception { McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( Arrays.asList(template1, template2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -492,7 +490,7 @@ void testReadResourceRequest() throws Exception { McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -507,7 +505,7 @@ void testReadResourceRequestWithMeta() throws Exception { McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -521,7 +519,7 @@ void testReadResourceRequestWithMeta() throws Exception { @Test void testReadResourceRequestDeserialization() throws Exception { - McpSchema.ReadResourceRequest request = mapper.readValue(""" + McpSchema.ReadResourceRequest request = JSON_MAPPER.readValue(""" {"uri":"resource://test","_meta":{"progressToken":"test-token"}}""", McpSchema.ReadResourceRequest.class); @@ -541,7 +539,7 @@ void testReadResourceResult() throws Exception { McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2), Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -562,7 +560,7 @@ void testPrompt() throws Exception { McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "A test prompt", Arrays.asList(arg1, arg2), Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(prompt); + String value = JSON_MAPPER.writeValueAsString(prompt); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -577,7 +575,7 @@ void testPromptMessage() throws Exception { McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); - String value = mapper.writeValueAsString(message); + String value = JSON_MAPPER.writeValueAsString(message); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -598,7 +596,7 @@ void testListPromptsResult() throws Exception { McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -615,7 +613,7 @@ void testGetPromptRequest() throws Exception { McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); - assertThat(mapper.readValue(""" + assertThat(JSON_MAPPER.readValue(""" {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) .isEqualTo(request); } @@ -631,7 +629,7 @@ void testGetPromptRequestWithMeta() throws Exception { McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments, meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -656,7 +654,7 @@ void testGetPromptResult() throws Exception { McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", Arrays.asList(message1, message2)); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -696,16 +694,16 @@ void testJsonSchema() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); + String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -739,16 +737,16 @@ void testJsonSchemaWithDefinitions() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); + String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -771,9 +769,13 @@ void testTool() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -805,16 +807,20 @@ void testToolWithComplexSchema() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("addressTool", "Handles addresses", complexSchemaJson); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("Handles addresses") + .inputSchema(JSON_MAPPER, complexSchemaJson) + .build(); // Serialize the tool to a string - String serialized = mapper.writeValueAsString(tool); + String serialized = JSON_MAPPER.writeValueAsString(tool); // Deserialize back to a Tool object - McpSchema.Tool deserializedTool = mapper.readValue(serialized, McpSchema.Tool.class); + McpSchema.Tool deserializedTool = JSON_MAPPER.readValue(serialized, McpSchema.Tool.class); // Serialize again and compare with first serialization - String serializedAgain = mapper.writeValueAsString(deserializedTool); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserializedTool); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -841,11 +847,16 @@ void testToolWithMeta() throws Exception { } """; - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); Map meta = Map.of("metaKey", "metaValue"); - McpSchema.Tool tool = new McpSchema.Tool("addressTool", "addressTool", "Handles addresses", schema, null, null, - meta); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("addressTool") + .description("Handles addresses") + .inputSchema(schema) + .meta(meta) + .build(); // Verify that meta value was preserved assertThat(tool.meta()).isNotNull(); @@ -871,9 +882,14 @@ void testToolWithAnnotations() throws Exception { McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool", false, false, false, false, false); - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson, annotations); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .annotations(annotations) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -934,9 +950,14 @@ void testToolWithOutputSchema() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", inputSchemaJson, outputSchemaJson, null); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -996,10 +1017,15 @@ void testToolWithOutputSchemaAndAnnotations() throws Exception { McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool with output", true, false, true, false, true); - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", inputSchemaJson, outputSchemaJson, - annotations); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .annotations(annotations) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1063,7 +1089,7 @@ void testToolDeserialization() throws Exception { } """; - McpSchema.Tool tool = mapper.readValue(toolJson, McpSchema.Tool.class); + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); assertThat(tool).isNotNull(); assertThat(tool.name()).isEqualTo("test-tool"); @@ -1097,7 +1123,7 @@ void testToolDeserializationWithoutOutputSchema() throws Exception { } """; - McpSchema.Tool tool = mapper.readValue(toolJson, McpSchema.Tool.class); + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); assertThat(tool).isNotNull(); assertThat(tool.name()).isEqualTo("test-tool"); @@ -1115,7 +1141,7 @@ void testCallToolRequest() throws Exception { McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1127,14 +1153,14 @@ void testCallToolRequest() throws Exception { @Test void testCallToolRequestJsonArguments() throws Exception { - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", """ + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(JSON_MAPPER, "test-tool", """ { "name": "test", "value": 42 } """); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1151,7 +1177,7 @@ void testCallToolRequestWithMeta() throws Exception { .arguments(Map.of("name", "test", "value", 42)) .progressToken("tool-progress-123") .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1170,14 +1196,18 @@ void testCallToolRequestBuilderWithJsonArguments() throws Exception { Map meta = new HashMap<>(); meta.put("progressToken", "json-builder-789"); - McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder().name("test-tool").arguments(""" - { - "name": "test", - "value": 42 - } - """).meta(meta).build(); + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(JSON_MAPPER, """ + { + "name": "test", + "value": 42 + } + """) + .meta(meta) + .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1206,9 +1236,11 @@ void testCallToolRequestBuilderNameRequired() { void testCallToolResult() throws Exception { McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); - McpSchema.CallToolResult result = new McpSchema.CallToolResult(Collections.singletonList(content), false); + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .content(Collections.singletonList(content)) + .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1224,7 +1256,7 @@ void testCallToolResultBuilder() throws Exception { .isError(false) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1244,7 +1276,7 @@ void testCallToolResultBuilderWithMultipleContents() throws Exception { .isError(false) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1262,7 +1294,7 @@ void testCallToolResultBuilderWithContentList() throws Exception { McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1279,7 +1311,7 @@ void testCallToolResultBuilderWithErrorResult() throws Exception { .isError(true) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1297,8 +1329,8 @@ void testCallToolResultStringConstructor() throws Exception { .isError(false) .build(); - String value1 = mapper.writeValueAsString(result1); - String value2 = mapper.writeValueAsString(result2); + String value1 = JSON_MAPPER.writeValueAsString(result1); + String value2 = JSON_MAPPER.writeValueAsString(result2); // Both should produce the same JSON assertThat(value1).isEqualTo(value2); @@ -1336,7 +1368,7 @@ void testCreateMessageRequest() throws Exception { .metadata(metadata) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1357,7 +1389,7 @@ void testCreateMessageResult() throws Exception { .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1372,7 +1404,7 @@ void testCreateMessageResultUnknownStopReason() throws Exception { String input = """ {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"arbitrary value"}"""; - McpSchema.CreateMessageResult value = mapper.readValue(input, McpSchema.CreateMessageResult.class); + McpSchema.CreateMessageResult value = JSON_MAPPER.readValue(input, McpSchema.CreateMessageResult.class); McpSchema.TextContent expectedContent = new McpSchema.TextContent("Assistant response"); McpSchema.CreateMessageResult expected = McpSchema.CreateMessageResult.builder() @@ -1393,7 +1425,7 @@ void testCreateElicitationRequest() throws Exception { Map.of("foo", Map.of("type", "string")))) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1409,7 +1441,7 @@ void testCreateElicitationResult() throws Exception { .message(McpSchema.ElicitResult.Action.ACCEPT) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1432,7 +1464,7 @@ void testElicitRequestWithMeta() throws Exception { .meta(meta) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1449,7 +1481,7 @@ void testElicitRequestWithMeta() throws Exception { void testPaginatedRequestNoArgs() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1465,7 +1497,7 @@ void testPaginatedRequestNoArgs() throws Exception { void testPaginatedRequestWithCursor() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123"); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1484,7 +1516,7 @@ void testPaginatedRequestWithMeta() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123", meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1498,7 +1530,7 @@ void testPaginatedRequestWithMeta() throws Exception { @Test void testPaginatedRequestDeserialization() throws Exception { - McpSchema.PaginatedRequest request = mapper.readValue(""" + McpSchema.PaginatedRequest request = JSON_MAPPER.readValue(""" {"cursor":"test-cursor","_meta":{"progressToken":"test-token"}}""", McpSchema.PaginatedRequest.class); assertThat(request.cursor()).isEqualTo("test-cursor"); @@ -1516,7 +1548,7 @@ void testCompleteRequest() throws Exception { McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(promptRef, argument); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1540,7 +1572,7 @@ void testCompleteRequestWithMeta() throws Exception { McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(resourceRef, argument, meta, null); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1559,7 +1591,7 @@ void testCompleteRequestWithMeta() throws Exception { void testRoot() throws Exception { McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root", Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(root); + String value = JSON_MAPPER.writeValueAsString(root); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1575,7 +1607,7 @@ void testListRootsResult() throws Exception { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1586,6 +1618,107 @@ void testListRootsResult() throws Exception { } + // Elicitation Capability Tests (Issue #724) + + @Test + void testElicitationCapabilityWithFormField() throws Exception { + // Test that elicitation with "form" field can be deserialized (2025-11-25 spec) + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityWithFormAndUrlFields() throws Exception { + // Test that elicitation with both "form" and "url" fields can be deserialized + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{},"url":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBackwardCompatibilityEmptyObject() throws Exception { + // Test backward compatibility: empty elicitation {} should still work + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderBackwardCompatibility() throws Exception { + // Test that the existing builder API still works and produces valid JSON + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder().elicitation().build(); + + assertThat(capabilities.elicitation()).isNotNull(); + + // Serialize and verify it produces valid JSON (should be {} for backward compat) + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"elicitation\""); + } + + @Test + void testElicitationCapabilitySerializationRoundTrip() throws Exception { + // Test that serialization and deserialization round-trip works + McpSchema.ClientCapabilities original = McpSchema.ClientCapabilities.builder().elicitation().build(); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ClientCapabilities deserialized = JSON_MAPPER.readValue(json, McpSchema.ClientCapabilities.class); + + assertThat(deserialized.elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderWithFormAndUrl() throws Exception { + // Test the new builder method that explicitly sets form and url support + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNotNull(); + + // Verify serialization produces the expected JSON + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThatJson(json).when(Option.IGNORING_ARRAY_ORDER).isObject().containsKey("elicitation"); + assertThat(json).contains("\"form\""); + assertThat(json).contains("\"url\""); + } + + @Test + void testElicitationCapabilityBuilderFormOnly() throws Exception { + // Test builder with form only + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, false) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNull(); + + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"form\""); + assertThat(json).doesNotContain("\"url\""); + } + // Progress Notification Tests @Test @@ -1593,7 +1726,7 @@ void testProgressNotificationWithMessage() throws Exception { McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-123", 0.5, 1.0, "Processing file 1 of 2", Map.of("key", "value")); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1604,7 +1737,7 @@ void testProgressNotificationWithMessage() throws Exception { @Test void testProgressNotificationDeserialization() throws Exception { - McpSchema.ProgressNotification notification = mapper.readValue( + McpSchema.ProgressNotification notification = JSON_MAPPER.readValue( """ {"progressToken":"token-456","progress":0.75,"total":1.0,"message":"Almost done","_meta":{"key":"value"}}""", McpSchema.ProgressNotification.class); @@ -1621,7 +1754,7 @@ void testProgressNotificationWithoutMessage() throws Exception { McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-789", 0.25, null, null); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java new file mode 100644 index 000000000..1d7be0b51 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java @@ -0,0 +1,100 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Test class to verify the equals method implementation for PromptReference. + */ +class PromptReferenceEqualsTest { + + @Test + void testEqualsWithSameIdentifierAndType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertTrue(ref1.equals(ref2), "PromptReferences with same identifier and type should be equal"); + assertEquals(ref1.hashCode(), ref2.hashCode(), "Equal objects should have same hash code"); + } + + @Test + void testEqualsWithDifferentIdentifier() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-1", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-2", + "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different identifiers should not be equal"); + } + + @Test + void testEqualsWithDifferentType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/other", "test-prompt", "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different types should not be equal"); + } + + @Test + void testEqualsWithNull() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertFalse(ref1.equals(null), "PromptReference should not be equal to null"); + } + + @Test + void testEqualsWithDifferentClass() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + String other = "not a PromptReference"; + + assertFalse(ref1.equals(other), "PromptReference should not be equal to different class"); + } + + @Test + void testEqualsWithSameInstance() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertTrue(ref1.equals(ref1), "PromptReference should be equal to itself"); + } + + @Test + void testEqualsIgnoresTitle() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 1"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 2"); + McpSchema.PromptReference ref3 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", null); + + assertTrue(ref1.equals(ref2), "PromptReferences should be equal regardless of title"); + assertTrue(ref1.equals(ref3), "PromptReferences should be equal even when one has null title"); + assertTrue(ref2.equals(ref3), "PromptReferences should be equal even when one has null title"); + } + + @Test + void testHashCodeConsistency() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertEquals(ref1.hashCode(), ref2.hashCode(), "Objects that are equal should have the same hash code"); + + // Call hashCode multiple times to ensure consistency + int hashCode1 = ref1.hashCode(); + int hashCode2 = ref1.hashCode(); + assertEquals(hashCode1, hashCode2, "Hash code should be consistent across multiple calls"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java new file mode 100644 index 000000000..ef7cd2737 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.spec.json.gson; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.ToNumberPolicy; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * Test-only Gson-based implementation of McpJsonMapper. This lives under src/test/java so + * it doesn't affect production code or dependencies. + */ +public final class GsonMcpJsonMapper implements McpJsonMapper { + + private final Gson gson; + + public GsonMcpJsonMapper() { + this(new GsonBuilder().serializeNulls() + // Ensure numeric values in untyped (Object) fields preserve integral numbers + // as Long + .setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .setNumberToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .create()); + } + + public GsonMcpJsonMapper(Gson gson) { + if (gson == null) { + throw new IllegalArgumentException("Gson must not be null"); + } + this.gson = gson; + } + + public Gson getGson() { + return gson; + } + + @Override + public T readValue(String content, Class type) throws IOException { + try { + return gson.fromJson(content, type); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + try { + return gson.fromJson(content, type.getType()); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T convertValue(Object fromValue, Class type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type.getType()); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + try { + return gson.toJson(value); + } + catch (Exception e) { + throw new IOException("Failed to serialize to JSON", e); + } + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return writeValueAsString(value).getBytes(StandardCharsets.UTF_8); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java new file mode 100644 index 000000000..498194d17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.spec.json.gson; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +class GsonMcpJsonMapperTests { + + record Person(String name, int age) { + } + + @Test + void roundTripSimplePojo() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + var input = new Person("Alice", 30); + String json = mapper.writeValueAsString(input); + assertNotNull(json); + assertTrue(json.contains("\"Alice\"")); + assertTrue(json.contains("\"age\"")); + + var decoded = mapper.readValue(json, Person.class); + assertEquals(input, decoded); + + byte[] bytes = mapper.writeValueAsBytes(input); + assertNotNull(bytes); + var decodedFromBytes = mapper.readValue(bytes, Person.class); + assertEquals(input, decodedFromBytes); + } + + @Test + void readWriteParameterizedTypeWithTypeRef() throws IOException { + var mapper = new GsonMcpJsonMapper(); + String json = "[\"a\", \"b\", \"c\"]"; + + List list = mapper.readValue(json, new TypeRef>() { + }); + assertEquals(List.of("a", "b", "c"), list); + + String encoded = mapper.writeValueAsString(list); + assertTrue(encoded.startsWith("[")); + assertTrue(encoded.contains("\"a\"")); + } + + @Test + void convertValueMapToRecordAndParameterized() { + var mapper = new GsonMcpJsonMapper(); + Map src = Map.of("name", "Bob", "age", 42); + + // Convert to simple record + Person person = mapper.convertValue(src, Person.class); + assertEquals(new Person("Bob", 42), person); + + // Convert to parameterized Map + Map toMap = mapper.convertValue(person, new TypeRef>() { + }); + assertEquals("Bob", toMap.get("name")); + assertEquals(42.0, ((Number) toMap.get("age")).doubleValue(), 0.0); // Gson may + // emit double + // for + // primitives + } + + @Test + void deserializeJsonRpcMessageRequestUsingCustomMapper() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + String json = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "ping", + "params": { "x": 1, "y": "z" } + } + """; + + var msg = McpSchema.deserializeJsonRpcMessage(mapper, json); + assertTrue(msg instanceof McpSchema.JSONRPCRequest); + + var req = (McpSchema.JSONRPCRequest) msg; + assertEquals("2.0", req.jsonrpc()); + assertEquals("ping", req.method()); + assertNotNull(req.id()); + assertEquals("1", req.id().toString()); + + assertNotNull(req.params()); + assertInstanceOf(Map.class, req.params()); + @SuppressWarnings("unchecked") + var params = (Map) req.params(); + assertEquals(1.0, ((Number) params.get("x")).doubleValue(), 0.0); + assertEquals("z", params.get("y")); + } + + @Test + void integrateWithMcpSchemaStaticMapperForStringParsing() { + var gsonMapper = new GsonMcpJsonMapper(); + + // Tool builder parsing of input/output schema strings + var tool = McpSchema.Tool.builder().name("echo").description("Echo tool").inputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "x": { "type": "integer" } }, + "required": ["x"] + } + """).outputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "y": { "type": "string" } } + } + """).build(); + + assertNotNull(tool.inputSchema()); + assertNotNull(tool.outputSchema()); + assertTrue(tool.outputSchema().containsKey("properties")); + + // CallToolRequest builder parsing of JSON arguments string + var call = McpSchema.CallToolRequest.builder().name("echo").arguments(gsonMapper, "{\"x\": 123}").build(); + + assertEquals("echo", call.name()); + assertNotNull(call.arguments()); + assertTrue(call.arguments().get("x") instanceof Number); + assertEquals(123.0, ((Number) call.arguments().get("x")).doubleValue(), 0.0); + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java index 08555fef5..0038d4e1b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java @@ -8,7 +8,9 @@ import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; class AssertTests { diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java similarity index 98% rename from mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java index 4de9363c2..d5ef8a91c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -16,7 +16,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSession; @@ -259,7 +259,7 @@ private static class MockMcpSession implements McpSession { private boolean shouldFailPing = false; @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { if (McpSchema.METHOD_PING.equals(method)) { pingCount.incrementAndGet(); if (shouldFailPing) { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..911506e01 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java diff --git a/mcp/src/test/resources/logback.xml b/mcp-core/src/test/resources/logback.xml similarity index 100% rename from mcp/src/test/resources/logback.xml rename to mcp-core/src/test/resources/logback.xml diff --git a/mcp-json-jackson2/pom.xml b/mcp-json-jackson2/pom.xml new file mode 100644 index 000000000..de2ac58ce --- /dev/null +++ b/mcp-json-jackson2/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json-jackson2 + jar + Java MCP SDK JSON Jackson + Java MCP SDK JSON implementation based on Jackson + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.networknt + json-schema-validator + ${json-schema-validator.version} + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java new file mode 100644 index 000000000..6aa2b4ebc --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; + +/** + * Jackson-based implementation of JsonMapper. Wraps a Jackson ObjectMapper but keeps the + * SDK decoupled from Jackson at the API level. + */ +public final class JacksonMcpJsonMapper implements McpJsonMapper { + + private final ObjectMapper objectMapper; + + /** + * Constructs a new JacksonMcpJsonMapper instance with the given ObjectMapper. + * @param objectMapper the ObjectMapper to be used for JSON serialization and + * deserialization. Must not be null. + * @throws IllegalArgumentException if the provided ObjectMapper is null. + */ + public JacksonMcpJsonMapper(ObjectMapper objectMapper) { + if (objectMapper == null) { + throw new IllegalArgumentException("ObjectMapper must not be null"); + } + this.objectMapper = objectMapper; + } + + /** + * Returns the underlying Jackson {@link ObjectMapper} used for JSON serialization and + * deserialization. + * @return the ObjectMapper instance + */ + public ObjectMapper getObjectMapper() { + return objectMapper; + } + + @Override + public T readValue(String content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T convertValue(Object fromValue, Class type) { + return objectMapper.convertValue(fromValue, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.convertValue(fromValue, javaType); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + return objectMapper.writeValueAsString(value); + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return objectMapper.writeValueAsBytes(value); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java new file mode 100644 index 000000000..0e79c3e0e --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.McpJsonMapperSupplier; + +/** + * A supplier of {@link McpJsonMapper} instances that uses the Jackson library for JSON + * serialization and deserialization. + *

+ * This implementation provides a {@link McpJsonMapper} backed by a Jackson + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + */ +public class JacksonMcpJsonMapperSupplier implements McpJsonMapperSupplier { + + /** + * Returns a new instance of {@link McpJsonMapper} that uses the Jackson library for + * JSON serialization and deserialization. + *

+ * The returned {@link McpJsonMapper} is backed by a new instance of + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + * @return a new {@link McpJsonMapper} instance + */ + @Override + public McpJsonMapper get() { + return new JacksonMcpJsonMapper(new com.fasterxml.jackson.databind.ObjectMapper()); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java similarity index 66% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java rename to mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java index f4bdc02eb..1ff28cb80 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java @@ -1,26 +1,22 @@ /* * Copyright 2024-2024 the original author or authors. */ +package io.modelcontextprotocol.json.schema.jackson; -package io.modelcontextprotocol.spec; - +import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.networknt.schema.JsonSchema; -import com.networknt.schema.JsonSchemaFactory; -import com.networknt.schema.SpecVersion; -import com.networknt.schema.ValidationMessage; - -import io.modelcontextprotocol.util.Assert; +import com.networknt.schema.Schema; +import com.networknt.schema.SchemaRegistry; +import com.networknt.schema.Error; +import com.networknt.schema.dialect.Dialects; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Default implementation of the {@link JsonSchemaValidator} interface. This class @@ -35,10 +31,10 @@ public class DefaultJsonSchemaValidator implements JsonSchemaValidator { private final ObjectMapper objectMapper; - private final JsonSchemaFactory schemaFactory; + private final SchemaRegistry schemaFactory; // TODO: Implement a strategy to purge the cache (TTL, size limit, etc.) - private final ConcurrentHashMap schemaCache; + private final ConcurrentHashMap schemaCache; public DefaultJsonSchemaValidator() { this(new ObjectMapper()); @@ -46,21 +42,27 @@ public DefaultJsonSchemaValidator() { public DefaultJsonSchemaValidator(ObjectMapper objectMapper) { this.objectMapper = objectMapper; - this.schemaFactory = JsonSchemaFactory.getInstance(SpecVersion.VersionFlag.V202012); + this.schemaFactory = SchemaRegistry.withDialect(Dialects.getDraft202012()); this.schemaCache = new ConcurrentHashMap<>(); } @Override - public ValidationResponse validate(Map schema, Map structuredContent) { + public ValidationResponse validate(Map schema, Object structuredContent) { - Assert.notNull(schema, "Schema must not be null"); - Assert.notNull(structuredContent, "Structured content must not be null"); + if (schema == null) { + throw new IllegalArgumentException("Schema must not be null"); + } + if (structuredContent == null) { + throw new IllegalArgumentException("Structured content must not be null"); + } try { - JsonNode jsonStructuredOutput = this.objectMapper.valueToTree(structuredContent); + JsonNode jsonStructuredOutput = (structuredContent instanceof String) + ? this.objectMapper.readTree((String) structuredContent) + : this.objectMapper.valueToTree(structuredContent); - Set validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); + List validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); // Check if validation passed if (!validationResult.isEmpty()) { @@ -83,36 +85,36 @@ public ValidationResponse validate(Map schema, Map schema) throws JsonProcessingException { + private Schema getOrCreateJsonSchema(Map schema) throws JsonProcessingException { // Generate cache key based on schema content String cacheKey = this.generateCacheKey(schema); // Try to get from cache first - JsonSchema cachedSchema = this.schemaCache.get(cacheKey); + Schema cachedSchema = this.schemaCache.get(cacheKey); if (cachedSchema != null) { return cachedSchema; } // Create new schema if not in cache - JsonSchema newSchema = this.createJsonSchema(schema); + Schema newSchema = this.createJsonSchema(schema); // Cache the schema - JsonSchema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); + Schema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); return existingSchema != null ? existingSchema : newSchema; } /** - * Creates a new JsonSchema from the given schema map. + * Creates a new Schema from the given schema map. * @param schema the schema map - * @return the compiled JsonSchema + * @return the compiled Schema * @throws JsonProcessingException if schema processing fails */ - private JsonSchema createJsonSchema(Map schema) throws JsonProcessingException { + private Schema createJsonSchema(Map schema) throws JsonProcessingException { // Convert schema map directly to JsonNode (more efficient than string // serialization) JsonNode schemaNode = this.objectMapper.valueToTree(schema); @@ -123,17 +125,6 @@ private JsonSchema createJsonSchema(Map schema) throws JsonProce }; } - // Handle additionalProperties setting - if (schemaNode.isObject()) { - ObjectNode objectSchemaNode = (ObjectNode) schemaNode; - if (!objectSchemaNode.has("additionalProperties")) { - // Clone the node before modification to avoid mutating the original - objectSchemaNode = objectSchemaNode.deepCopy(); - objectSchemaNode.put("additionalProperties", false); - schemaNode = objectSchemaNode; - } - } - return this.schemaFactory.getSchema(schemaNode); } diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..86153a538 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema.jackson; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; + +/** + * A concrete implementation of {@link JsonSchemaValidatorSupplier} that provides a + * {@link JsonSchemaValidator} instance based on the Jackson library. + * + * @see JsonSchemaValidatorSupplier + * @see JsonSchemaValidator + */ +public class JacksonJsonSchemaValidatorSupplier implements JsonSchemaValidatorSupplier { + + /** + * Returns a new instance of {@link JsonSchemaValidator} that uses the Jackson library + * for JSON schema validation. + * @return A {@link JsonSchemaValidator} instance. + */ + @Override + public JsonSchemaValidator get() { + return new DefaultJsonSchemaValidator(); + } + +} diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier new file mode 100644 index 000000000..8ea66d698 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapperSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier new file mode 100644 index 000000000..0fb0b7e5a --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.schema.jackson.JacksonJsonSchemaValidatorSupplier \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java similarity index 83% rename from mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java rename to mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java index 30158543d..7642f0480 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java @@ -2,7 +2,7 @@ * Copyright 2024-2024 the original author or authors. */ -package io.modelcontextprotocol.spec; +package io.modelcontextprotocol.json; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -13,9 +13,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import java.util.List; import java.util.Map; import java.util.stream.Stream; +import io.modelcontextprotocol.json.schema.jackson.DefaultJsonSchemaValidator; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -27,7 +29,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.JsonSchemaValidator.ValidationResponse; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; /** * Tests for {@link DefaultJsonSchemaValidator}. @@ -63,6 +65,16 @@ private Map toMap(String json) { } } + private List> toListMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + @Test void testDefaultConstructor() { DefaultJsonSchemaValidator defaultValidator = new DefaultJsonSchemaValidator(); @@ -197,6 +209,74 @@ void testValidateWithValidArraySchema() { assertNull(response.errorMessage()); } + @Test + void testValidateWithValidArraySchemaTopLevelArray() { + String schemaJson = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "city" : { + "type" : "string" + }, + "summary" : { + "type" : "string" + }, + "temperatureC" : { + "type" : "number", + "format" : "float" + } + }, + "required" : [ "city", "summary", "temperatureC" ] + }, + "additionalProperties" : false + } + """; + + String contentJson = """ + [ + { + "city": "London", + "summary": "Generally mild with frequent rainfall. Winters are cool and damp, summers are warm but rarely hot. Cloudy conditions are common throughout the year.", + "temperatureC": 11.3 + }, + { + "city": "New York", + "summary": "Four distinct seasons with hot and humid summers, cold winters with snow, and mild springs and autumns. Precipitation is fairly evenly distributed throughout the year.", + "temperatureC": 12.8 + }, + { + "city": "San Francisco", + "summary": "Mild year-round with a distinctive Mediterranean climate. Famous for summer fog, mild winters, and little temperature variation throughout the year. Very little rainfall in summer months.", + "temperatureC": 14.6 + }, + { + "city": "Tokyo", + "summary": "Humid subtropical climate with hot, wet summers and mild winters. Experiences a rainy season in early summer and occasional typhoons in late summer to early autumn.", + "temperatureC": 15.4 + } + ] + """; + + Map schema = toMap(schemaJson); + + // Validate as JSON string + ValidationResponse response = validator.validate(schema, contentJson); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + + List> structuredContent = toListMap(contentJson); + + // Validate as List> + response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + @Test void testValidateWithInvalidTypeSchema() { String schemaJson = """ @@ -265,7 +345,8 @@ void testValidateWithAdditionalPropertiesNotAllowed() { "properties": { "name": {"type": "string"} }, - "required": ["name"] + "required": ["name"], + "additionalProperties": false } """; @@ -315,6 +396,35 @@ void testValidateWithAdditionalPropertiesExplicitlyAllowed() { assertNull(response.errorMessage()); } + @Test + void testValidateWithDefaultAdditionalProperties() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + @Test void testValidateWithAdditionalPropertiesExplicitlyDisallowed() { String schemaJson = """ diff --git a/mcp-json/pom.xml b/mcp-json/pom.xml new file mode 100644 index 000000000..2cbcf3516 --- /dev/null +++ b/mcp-json/pom.xml @@ -0,0 +1,39 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json + jar + Java MCP SDK JSON Support + Java MCP SDK JSON Support API + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + + diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java new file mode 100644 index 000000000..31930ab33 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Utility class for creating a default {@link McpJsonMapper} instance. This class + * provides a single method to create a default mapper using the {@link ServiceLoader} + * mechanism. + */ +final class McpJsonInternal { + + private static McpJsonMapper defaultJsonMapper = null; + + /** + * Returns the cached default {@link McpJsonMapper} instance. If the default mapper + * has not been created yet, it will be initialized using the + * {@link #createDefaultMapper()} method. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper getDefaultMapper() { + if (defaultJsonMapper == null) { + defaultJsonMapper = McpJsonInternal.createDefaultMapper(); + } + return defaultJsonMapper; + } + + /** + * Creates a default {@link McpJsonMapper} instance using the {@link ServiceLoader} + * mechanism. The default mapper is resolved by loading the first available + * {@link McpJsonMapperSupplier} implementation on the classpath. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper createDefaultMapper() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(McpJsonMapperSupplier.class).stream().flatMap(p -> { + try { + McpJsonMapperSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.ofNullable(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default McpJsonMapper implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default McpJsonMapper", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java new file mode 100644 index 000000000..1e30cad16 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java @@ -0,0 +1,110 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.io.IOException; + +/** + * Abstraction for JSON serialization/deserialization to decouple the SDK from any + * specific JSON library. A default implementation backed by Jackson is provided in + * io.modelcontextprotocol.spec.json.jackson.JacksonJsonMapper. + */ +public interface McpJsonMapper { + + /** + * Deserialize JSON string into a target type. + * @param content JSON as String + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, Class type) throws IOException; + + /** + * Deserialize JSON bytes into a target type. + * @param content JSON as bytes + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, Class type) throws IOException; + + /** + * Deserialize JSON string into a parameterized target type. + * @param content JSON as String + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, TypeRef type) throws IOException; + + /** + * Deserialize JSON bytes into a parameterized target type. + * @param content JSON as bytes + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, TypeRef type) throws IOException; + + /** + * Convert a value to a given type, useful for mapping nested JSON structures. + * @param fromValue source value + * @param type target class + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, Class type); + + /** + * Convert a value to a given parameterized type. + * @param fromValue source value + * @param type target type reference + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, TypeRef type); + + /** + * Serialize an object to JSON string. + * @param value object to serialize + * @return JSON as String + * @throws IOException on serialization errors + */ + String writeValueAsString(Object value) throws IOException; + + /** + * Serialize an object to JSON bytes. + * @param value object to serialize + * @return JSON as bytes + * @throws IOException on serialization errors + */ + byte[] writeValueAsBytes(Object value) throws IOException; + + /** + * Returns the default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper getDefault() { + return McpJsonInternal.getDefaultMapper(); + } + + /** + * Creates a new default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper createDefault() { + return McpJsonInternal.createDefaultMapper(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java new file mode 100644 index 000000000..619f96040 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java @@ -0,0 +1,14 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.function.Supplier; + +/** + * Strategy interface for resolving a {@link McpJsonMapper}. + */ +public interface McpJsonMapperSupplier extends Supplier { + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java new file mode 100644 index 000000000..ab37b43f3 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * Captures generic type information at runtime for parameterized JSON (de)serialization. + * Usage: TypeRef<List<Foo>> ref = new TypeRef<>(){}; + */ +public abstract class TypeRef { + + private final Type type; + + /** + * Constructs a new TypeRef instance, capturing the generic type information of the + * subclass. This constructor should be called from an anonymous subclass to capture + * the actual type arguments. For example:

+	 * TypeRef<List<Foo>> ref = new TypeRef<>(){};
+	 * 
+ * @throws IllegalStateException if TypeRef is not subclassed with actual type + * information + */ + protected TypeRef() { + Type superClass = getClass().getGenericSuperclass(); + if (superClass instanceof Class) { + throw new IllegalStateException("TypeRef constructed without actual type information"); + } + this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0]; + } + + /** + * Returns the captured type information. + * @return the Type representing the actual type argument captured by this TypeRef + * instance + */ + public Type getType() { + return type; + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java new file mode 100644 index 000000000..2497e7f80 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Internal utility class for creating a default {@link JsonSchemaValidator} instance. + * This class uses the {@link ServiceLoader} to discover and instantiate a + * {@link JsonSchemaValidatorSupplier} implementation. + */ +final class JsonSchemaInternal { + + private static JsonSchemaValidator defaultValidator = null; + + /** + * Returns the default {@link JsonSchemaValidator} instance. If the default validator + * has not been initialized, it will be created using the {@link ServiceLoader} to + * discover and instantiate a {@link JsonSchemaValidatorSupplier} implementation. + * @return The default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation exists on the classpath or if an error occurs during instantiation. + */ + static JsonSchemaValidator getDefaultValidator() { + if (defaultValidator == null) { + defaultValidator = JsonSchemaInternal.createDefaultValidator(); + } + return defaultValidator; + } + + /** + * Creates a default {@link JsonSchemaValidator} instance by loading a + * {@link JsonSchemaValidatorSupplier} implementation using the {@link ServiceLoader}. + * @return A default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation is found or if an error occurs during instantiation. + */ + static JsonSchemaValidator createDefaultValidator() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(JsonSchemaValidatorSupplier.class).stream().flatMap(p -> { + try { + JsonSchemaValidatorSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.of(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default JsonSchemaValidatorSupplier implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default JsonSchemaValidatorSupplier", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java new file mode 100644 index 000000000..8e35c0237 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + + /** + * Creates the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator createDefault() { + return JsonSchemaInternal.createDefaultValidator(); + } + + /** + * Returns the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator getDefault() { + return JsonSchemaInternal.getDefaultValidator(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..6f69169a0 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.function.Supplier; + +/** + * A supplier interface that provides a {@link JsonSchemaValidator} instance. + * Implementations of this interface are expected to return a new or cached instance of + * {@link JsonSchemaValidator} when {@link #get()} is invoked. + * + * @see JsonSchemaValidator + * @see Supplier + */ +public interface JsonSchemaValidatorSupplier extends Supplier { + +} diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 300d518e7..f1737a477 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -22,16 +22,22 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 853aed2bf..a8a4762c2 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; @@ -22,9 +24,10 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.HttpHeaders; @@ -75,8 +78,6 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); - private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; - private static final String DEFAULT_ENDPOINT = "/mcp"; /** @@ -88,7 +89,7 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { }; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final WebClient webClient; @@ -98,25 +99,35 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference> activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); private final AtomicReference> exceptionHandler = new AtomicReference<>(); - private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, - String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { - this.objectMapper = objectMapper; + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private WebClientStreamableHttpTransport(McpJsonMapper jsonMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); + this.supportedProtocolVersions = List.copyOf(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); } @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return supportedProtocolVersions; } /** @@ -142,12 +153,12 @@ public Mono connect(Function, Mono createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() : webClient.delete() .uri(this.endpoint) .header(HttpHeaders.MCP_SESSION_ID, sessionId) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, this.latestSupportedProtocolVersion) .retrieve() .toBodilessEntity() .onErrorComplete(e -> { @@ -158,6 +169,14 @@ private DefaultMcpTransportSession createTransportSession() { return new DefaultMcpTransportSession(onClose); } + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + @Override public void setExceptionHandler(Consumer handler) { logger.debug("Exception handler registered"); @@ -181,9 +200,9 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); if (currentSession != null) { - return currentSession.closeGracefully(); + return Mono.from(currentSession.closeGracefully()); } return Mono.empty(); }); @@ -207,7 +226,9 @@ private Mono reconnect(McpTransportStream stream) { Disposable connection = webClient.get() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .headers(httpHeaders -> { transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); if (stream != null) { @@ -270,10 +291,12 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); - Disposable connection = webClient.post() + Disposable connection = Flux.deferContextual(ctx -> webClient.post() .uri(this.endpoint) .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .headers(httpHeaders -> { transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); }) @@ -292,9 +315,10 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // 200 OK for notifications if (response.statusCode().is2xxSuccessful()) { Optional contentType = response.headers().contentType(); + long contentLength = response.headers().contentLength().orElse(-1); // Existing SDKs consume notifications with no response body nor // content type - if (contentType.isEmpty()) { + if (contentType.isEmpty() || contentLength == 0) { logger.trace("Message was successfully sent via POST for session {}", sessionRepresentation); // signal the caller that the message was successfully @@ -331,7 +355,7 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } return this.extractError(response, sessionRepresentation); } - }) + })) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorComplete(t -> { // handle the error first @@ -366,8 +390,7 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; Exception toPropagate; try { - McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, - McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse jsonRpcResponse = jsonMapper.readValue(body, McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) : new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse); @@ -422,12 +445,13 @@ private Flux directResponseFlux(McpSchema.JSONRPCMessa ClientResponse response) { return response.bodyToMono(String.class).>handle((responseMessage, s) -> { try { - if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(responseMessage)) { - logger.warn("Notification: {} received non-compliant response: {}", sentMessage, responseMessage); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(responseMessage) ? responseMessage : "[empty]"); s.complete(); } else { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(jsonMapper, responseMessage); s.next(List.of(jsonRpcResponse)); } @@ -447,8 +471,8 @@ private Flux newEventStream(ClientResponse response, S } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } private Tuple2, Iterable> parse(ServerSentEvent event) { @@ -456,7 +480,7 @@ private Tuple2, Iterable> parse(Serve try { // We don't support batching ATM and probably won't since the next version // considers removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); } catch (IOException ioException) { @@ -474,7 +498,7 @@ private Tuple2, Iterable> parse(Serve */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private WebClient.Builder webClientBuilder; @@ -484,19 +508,22 @@ public static class Builder { private boolean openConnectionOnStartup = false; + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + private Builder(WebClient.Builder webClientBuilder) { Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); this.webClientBuilder = webClientBuilder; } /** - * Configure the {@link ObjectMapper} to use. - * @param objectMapper instance to use + * Configure the {@link McpJsonMapper} to use. + * @param jsonMapper instance to use * @return the builder instance */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -549,16 +576,38 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { return this; } + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

+ * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + /** * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using * the current builder configuration. * @return a new instance of {@link WebClientStreamableHttpTransport} */ public WebClientStreamableHttpTransport build() { - ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - - return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, - openConnectionOnStartup); + return new WebClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + webClientBuilder, endpoint, resumableStreams, openConnectionOnStartup, supportedProtocolVersions); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 51d21d18b..91b89d6d2 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,8 +9,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; @@ -100,10 +100,10 @@ public class WebFluxSseClientTransport implements McpClientTransport { private final WebClient webClient; /** - * ObjectMapper for serializing outbound messages and deserializing inbound messages. + * JSON mapper for serializing outbound messages and deserializing inbound messages. * Handles conversion between JSON-RPC messages and their string representation. */ - protected ObjectMapper objectMapper; + protected McpJsonMapper jsonMapper; /** * Subscription for the SSE connection handling inbound messages. Used for cleanup @@ -129,27 +129,16 @@ public class WebFluxSseClientTransport implements McpClientTransport { */ private String sseEndpoint; - /** - * Constructs a new SseClientTransport with the specified WebClient builder. Uses a - * default ObjectMapper instance for JSON processing. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @throws IllegalArgumentException if webClientBuilder is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { - this(webClientBuilder, new ObjectMapper()); - } - /** * Constructs a new SseClientTransport with the specified WebClient builder and * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @param objectMapper the ObjectMapper to use for JSON processing + * @param jsonMapper the ObjectMapper to use for JSON processing * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + this(webClientBuilder, jsonMapper, DEFAULT_SSE_ENDPOINT); } /** @@ -157,17 +146,16 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @param objectMapper the ObjectMapper to use for JSON processing + * @param jsonMapper the ObjectMapper to use for JSON processing * @param sseEndpoint the SSE endpoint URI to use for establishing the connection * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); this.sseEndpoint = sseEndpoint; } @@ -217,7 +205,7 @@ public Mono connect(Function, Mono> h } else if (MESSAGE_EVENT_TYPE.equals(event.event())) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); s.next(message); } catch (IOException ioException) { @@ -255,7 +243,7 @@ public Mono sendMessage(JSONRPCMessage message) { return Mono.empty(); } try { - String jsonText = this.objectMapper.writeValueAsString(message); + String jsonText = this.jsonMapper.writeValueAsString(message); return webClient.post() .uri(messageEndpointUri) .contentType(MediaType.APPLICATION_JSON) @@ -349,13 +337,13 @@ public Mono closeGracefully() { // @formatter:off * type conversion capabilities to handle complex object structures. * @param the target type to convert the data into * @param data the source object to convert - * @param typeRef the TypeReference describing the target type + * @param typeRef the TypeRef describing the target type * @return the unmarshalled object of type T * @throws IllegalArgumentException if the conversion cannot be performed */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } /** @@ -377,7 +365,7 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; /** * Creates a new builder with the specified WebClient.Builder. @@ -400,13 +388,13 @@ public Builder sseEndpoint(String sseEndpoint) { } /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper + * Sets the JSON mapper for serialization/deserialization. + * @param jsonMapper the JsonMapper to use * @return this builder */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -415,7 +403,8 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public WebFluxSseClientTransport build() { - return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint); + return new WebFluxSseClientTransport(webClientBuilder, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, sseEndpoint); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index aaf7bab46..0c80c5b8b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -9,8 +9,10 @@ import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -19,7 +21,6 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.KeepAliveScheduler; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -34,6 +35,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.util.UriComponentsBuilder; /** * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using @@ -92,9 +94,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String SESSION_ID = "sessionId"; + public static final String DEFAULT_BASE_URL = ""; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; /** * Base URL for the message endpoint. This is used to construct the full URL for @@ -115,6 +119,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** * Flag indicating if the transport is shutting down. */ @@ -126,83 +132,34 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private KeepAliveScheduler keepAliveScheduler; - /** - * Constructs a new WebFlux SSE server transport provider instance with the default - * SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - /** * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param baseUrl webflux message base path - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. + * @param jsonMapper The ObjectMapper to use for JSON serialization/deserialization of + * MCP messages. Must not be null. * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @param sseEndpoint The SSE endpoint path. Must not be null. * @param keepAliveInterval The interval for sending keep-alive pings to clients. + * @param contextExtractor The context extractor to use for extracting MCP transport + * context from HTTP requests. Must not be null. * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -268,6 +225,7 @@ public Mono notifyClients(String method, Object params) { // FIXME: This javadoc makes claims about using isClosing flag but it's not // actually // doing that. + /** * Initiates a graceful shutdown of all the sessions. This method ensures all active * sessions are properly closed and cleaned up. @@ -315,6 +273,8 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + McpTransportContext transportContext = this.contextExtractor.extract(request); + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -328,15 +288,28 @@ private Mono handleSseConnection(ServerRequest request) { // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder() - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) - .build()); + sink.next( + ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)).build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); sessions.remove(sessionId); }); - }), ServerSentEvent.class); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); } /** @@ -370,9 +343,11 @@ private Mono handleMessage(ServerRequest request) { .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); } + McpTransportContext transportContext = this.contextExtractor.extract(request); + return request.bodyToMono(String.class).flatMap(body -> { try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { logger.error("Error processing message: {}", error.getMessage()); // TODO: instead of signalling the error, just respond with 200 OK @@ -386,7 +361,7 @@ private Mono handleMessage(ServerRequest request) { logger.error("Failed to deserialize message: {}", e.getMessage()); return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); } - }); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } private class WebFluxMcpSessionTransport implements McpServerTransport { @@ -401,7 +376,7 @@ public WebFluxMcpSessionTransport(FluxSink> sink) { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromSupplier(() -> { try { - return objectMapper.writeValueAsString(message); + return jsonMapper.writeValueAsString(message); } catch (IOException e) { throw Exceptions.propagate(e); @@ -420,8 +395,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -448,7 +423,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String baseUrl = DEFAULT_BASE_URL; @@ -458,16 +433,19 @@ public static class Builder { private Duration keepAliveInterval; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The McpJsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -519,6 +497,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -526,11 +520,9 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * @throws IllegalStateException if required parameters are not set */ public WebFluxSseServerTransportProvider build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java index 23fff25b3..400be341e 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -4,14 +4,13 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStatelessServerTransport; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +34,7 @@ public class WebFluxStatelessServerTransport implements McpStatelessServerTransp private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class); - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String mcpEndpoint; @@ -47,13 +46,13 @@ public class WebFluxStatelessServerTransport implements McpStatelessServerTransp private volatile boolean isClosing = false; - private WebFluxStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() @@ -97,7 +96,7 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) @@ -107,13 +106,20 @@ private Mono handlePost(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest) - .flatMap(jsonrpcResponse -> ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(jsonrpcResponse)); + return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest).flatMap(jsonrpcResponse -> { + try { + String json = jsonMapper.writeValueAsString(jsonrpcResponse); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).bodyValue(json); + } + catch (IOException e) { + logger.error("Failed to serialize response: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError("Failed to serialize response")); + } + }); } else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) @@ -147,26 +153,27 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Builder() { // used by a static method } /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The JsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -205,10 +212,9 @@ public Builder contextExtractor(McpTransportContextExtractor cont * @throws IllegalStateException if required parameters are not set */ public WebFluxStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebFluxStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index f3f6c2c33..deebfc616 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -4,9 +4,9 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; @@ -15,7 +15,6 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.KeepAliveScheduler; @@ -50,7 +49,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe public static final String MESSAGE_EVENT_TYPE = "message"; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String mcpEndpoint; @@ -68,14 +67,14 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private KeepAliveScheduler keepAliveScheduler; - private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor, boolean disallowDelete, Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; this.disallowDelete = disallowDelete; @@ -98,7 +97,8 @@ private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Stri @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); } @Override @@ -166,7 +166,7 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); return Mono.defer(() -> { List acceptHeaders = request.headers().asHttpHeaders().getAccept(); @@ -174,7 +174,7 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.badRequest().build(); } - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().build(); // TODO: say we need a session // id } @@ -187,11 +187,13 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.notFound().build(); } - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) - .body(session.replay(lastId), ServerSentEvent.class); + .body(session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); } return ServerResponse.ok() @@ -202,7 +204,9 @@ private Mono handleGet(ServerRequest request) { McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); sink.onDispose(listeningStream::close); - }), ServerSentEvent.class); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } @@ -217,7 +221,7 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) @@ -227,12 +231,13 @@ private Mono handlePost(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { - }); + var typeReference = new TypeRef() { + }; + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + typeReference); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); sessions.put(init.session().getId(), init.session()); @@ -240,7 +245,7 @@ private Mono handlePost(ServerRequest request) { McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); try { - return this.objectMapper.writeValueAsString(jsonrpcResponse); + return this.jsonMapper.writeValueAsString(jsonrpcResponse); } catch (IOException e) { logger.warn("Failed to serialize initResponse", e); @@ -253,7 +258,7 @@ private Mono handlePost(ServerRequest request) { .bodyValue(initResult)); } - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); } @@ -282,7 +287,10 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { return true; }).contextWrite(sink.contextView()).subscribe(); sink.onCancel(streamSubscription); - }), ServerSentEvent.class); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); } else { return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type")); @@ -302,10 +310,10 @@ private Mono handleDelete(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); return Mono.defer(() -> { - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().build(); // TODO: say we need a session // id } @@ -343,7 +351,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { return Mono.fromSupplier(() -> { try { - return objectMapper.writeValueAsString(message); + return jsonMapper.writeValueAsString(message); } catch (IOException e) { throw Exceptions.propagate(e); @@ -363,8 +371,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -391,11 +399,12 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private boolean disallowDelete; @@ -406,15 +415,15 @@ private Builder() { } /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * Sets the {@link McpJsonMapper} to use for JSON serialization/deserialization of + * MCP messages. + * @param jsonMapper The {@link McpJsonMapper} instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -475,13 +484,12 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * @throws IllegalStateException if required parameters are not set */ public WebFluxStreamableServerTransportProvider build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor, + return new WebFluxStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor, disallowDelete, keepAliveInterval); } } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 6140fe489..eb8abb90c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -5,23 +5,28 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; - -import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.web.reactive.function.server.ServerRequest; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import reactor.netty.DisposableServer; @@ -40,6 +45,13 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests private WebFluxSseServerTransportProvider mcpServerTransportProvider; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { @@ -72,9 +84,9 @@ protected SingleSessionSyncSpecification prepareSyncServerBuilder() { public void before() { this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .build(); HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java index 5516e55b7..96a786a9e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -5,17 +5,17 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; - -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; @@ -38,6 +38,10 @@ class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests private WebFluxStatelessServerTransport mcpStreamableServerTransport; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { clientBuilders @@ -67,7 +71,6 @@ protected StatelessSyncSpecification prepareSyncServerBuilder() { @BeforeEach public void before() { this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..5d2bfda68 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.utils.McpTestRequestRecordingExchangeFilterFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +class WebFluxStreamableHttpVersionNegotiationIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private DisposableServer httpServer; + + private final McpTestRequestRecordingExchangeFilterFunction recordingFilterFunction = new McpTestRequestRecordingExchangeFilterFunction(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + private final WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(req -> McpTransportContext + .create(Map.of("protocol-version", req.headers().firstHeader("MCP-protocol-version")))) + .build(); + + private final McpSyncServer mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @BeforeEach + void setUp() { + RouterFunction filteredRouter = mcpStreamableServerTransportProvider.getRouterFunction() + .filter(recordingFilterFunction); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(filteredRouter); + + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + if (mcpServer != null) { + mcpServer.close(); + } + } + + @Test + void usesLatestVersion() { + var client = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .build()) + .requestTimeout(Duration.ofHours(10)) + .build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + var transport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).requestTimeout(Duration.ofHours(10)).build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls) + .filteredOn(c -> !c.body().contains("\"method\":\"initialize\"") && c.method().equals(HttpMethod.POST)) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index 9eba0e57c..5ab651931 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -5,23 +5,28 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; - -import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.web.reactive.function.server.ServerRequest; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; import reactor.netty.DisposableServer; @@ -38,6 +43,13 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { @@ -69,8 +81,8 @@ protected SyncSpecification prepareSyncServerBuilder() { public void before() { this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .build(); HttpHandler httpHandler = RouterFunctions diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index f8a16c153..1a4eedd15 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; @@ -19,7 +21,7 @@ public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncCli // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -30,15 +32,15 @@ protected McpClientTransport createMcpTransport() { return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java index 5e9960d0e..16f1d79a6 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; @@ -19,7 +21,7 @@ public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClien // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -30,15 +32,15 @@ protected McpClientTransport createMcpTransport() { return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0edf4cd54..0a92beac4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -6,6 +6,8 @@ import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; @@ -26,7 +28,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -37,15 +39,15 @@ protected McpClientTransport createMcpTransport() { return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 9b0959a35..0f35f9f0d 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -8,10 +8,11 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; - import org.springframework.web.reactive.function.client.WebClient; /** @@ -26,7 +27,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -37,15 +38,15 @@ protected McpClientTransport createMcpTransport() { return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java index cdbb97e17..214fa489b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..e2fcf91f7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import org.springframework.web.reactive.function.client.WebClient; + +class WebClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + static WebClient.Builder builder; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + builder = WebClient.builder().baseUrl(host); + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void testCloseUninitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 1cf5dffe2..1150e47f5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -11,9 +11,13 @@ import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -27,6 +31,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -42,7 +47,7 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -52,8 +57,6 @@ class WebFluxSseClientTransportTests { private WebClient.Builder webClientBuilder; - private ObjectMapper objectMapper; - // Test class to access protected methods static class TestSseClientTransport extends WebFluxSseClientTransport { @@ -61,8 +64,8 @@ static class TestSseClientTransport extends WebFluxSseClientTransport { private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - super(webClientBuilder, objectMapper); + public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + super(webClientBuilder, jsonMapper); } @Override @@ -95,18 +98,22 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } + @AfterAll + static void cleanup() { + container.stop(); + } + @BeforeEach void setUp() { - startContainer(); webClientBuilder = WebClient.builder().baseUrl(host); - objectMapper = new ObjectMapper(); - transport = new TestSseClientTransport(webClientBuilder, objectMapper); + transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER); transport.connect(Function.identity()).block(); } @@ -115,11 +122,6 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); - } - - void cleanup() { - container.stop(); } @Test @@ -129,12 +131,13 @@ void testEndpointEventHandling() { @Test void constructorValidation() { - assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> new WebFluxSseClientTransport(null, JSON_MAPPER)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("WebClient.Builder must not be null"); assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ObjectMapper must not be null"); + .hasMessageContaining("jsonMapper must not be null"); } @Test @@ -146,7 +149,7 @@ void testBuilderPattern() { // Test builder with custom ObjectMapper ObjectMapper customMapper = new ObjectMapper(); WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) + .jsonMapper(new JacksonMcpJsonMapper(customMapper)) .build(); assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); @@ -158,7 +161,6 @@ void testBuilderPattern() { // Test builder with all custom parameters WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) .sseEndpoint("/custom-sse") .build(); assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..3db0bbd3a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers using Spring WebFlux infrastructure. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication for asynchronous client and server implementations. It tests various + * combinations of client types and server transport mechanisms (stateless, streamable, + * SSE) to ensure proper context handling across different configurations. + * + *

Context Propagation Flow

+ *
    + *
  1. Client sets a value in its transport context via thread-local Reactor context
  2. + *
  3. Client-side context provider extracts the value and adds it as an HTTP header to + * the request
  4. + *
  5. Server-side context extractor reads the header from the incoming request
  6. + *
  7. Server handler receives the extracted context and returns the value as the tool + * call result
  8. + *
  9. Test verifies the round-trip context propagation was successful
  10. + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HEADER_NAME = "x-test"; + + // Async client context provider + ExchangeFilterFunction asyncClientContextProvider = (request, next) -> Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = transportContext.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }); + + // Tools + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + // Server context extractor + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + // Server transports + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + // Async clients + private final McpAsyncClient asyncStreamableClient = McpClient + .async(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopHttpServer(); + } + + @Test + void asyncClientStatelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..94e16e73e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP client and + * server using synchronous operations in a Spring WebFlux environment. + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different WebFlux-based MCP transport implementations + * + *

+ * The test scenario follows these steps: + *

    + *
  1. The client stores a value in a thread-local variable
  2. + *
  3. The client's transport context provider reads this value and includes it in the MCP + * context
  4. + *
  5. A WebClient filter extracts the context value and adds it as an HTTP header + * (x-test)
  6. + *
  7. The server's {@link McpTransportContextExtractor} reads the header from the + * request
  8. + *
  9. The server returns the header value as the tool call result, validating the + * round-trip
  10. + *
+ * + *

+ * This test demonstrates how custom context can be propagated through HTTP headers in a + * reactive WebFlux environment, enabling features like authentication tokens, correlation + * IDs, or other metadata to flow between MCP client and server. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @since 1.0.0 + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see WebFluxStatelessServerTransport + * @see WebFluxStreamableServerTransportProvider + * @see WebFluxSseServerTransportProvider + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, request) -> { + return new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + }; + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient.sync(WebFluxSseClientTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()).transportContextProvider(clientContextProvider).build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopHttpServer(); + } + + @Test + void statelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index a3bdf10b0..fe0314687 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -30,8 +29,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; private McpServerTransportProvider createMcpTransportProvider() { - var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) + var transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT) .build(); HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 3e28e96b8..67ef90bdf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -37,9 +36,7 @@ protected McpServer.SyncSpecification prepareSyncServerBuilder() { } private McpServerTransportProvider createMcpTransportProvider() { - transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); + transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT).build(); return transportProvider; } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java index 959f2f472..9b5a80f16 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -32,7 +31,6 @@ class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private McpStreamableServerTransportProvider createMcpTransportProvider() { var transportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java index 3396d489c..6a47ba3ae 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -32,7 +31,6 @@ class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { private McpStreamableServerTransportProvider createMcpTransportProvider() { var transportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java new file mode 100644 index 000000000..67347573c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.utils; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.createDefault(); + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java new file mode 100644 index 000000000..55129d481 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.utils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.web.reactive.function.server.HandlerFilterFunction; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Simple {@link HandlerFilterFunction} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingExchangeFilterFunction implements HandlerFilterFunction { + + private final List calls = new ArrayList<>(); + + @Override + public Mono filter(ServerRequest request, HandlerFunction next) { + Map headers = request.headers() + .asHttpHeaders() + .keySet() + .stream() + .collect(Collectors.toMap(String::toLowerCase, k -> String.join(",", request.headers().header(k)))); + + var cr = request.bodyToMono(String.class).defaultIfEmpty("").map(body -> { + this.calls.add(new Call(request.method(), headers, body)); + return ServerRequest.from(request).body(body).build(); + }); + + return cr.flatMap(next::handle); + + } + + public List getCalls() { + return List.copyOf(calls); + } + + public record Call(HttpMethod method, Map headers, String body) { + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index ea262d3a1..df18b1b8b 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -22,10 +22,16 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT @@ -37,14 +43,14 @@ io.modelcontextprotocol.sdk mcp-test - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT test - + io.modelcontextprotocol.sdk mcp-spring-webflux - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index ff452ca74..6c35de56d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -7,21 +7,21 @@ import java.io.IOException; import java.time.Duration; import java.util.List; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.KeepAliveScheduler; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -33,6 +33,7 @@ import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse.SseBuilder; +import org.springframework.web.util.UriComponentsBuilder; /** * Server-side implementation of the Model Context Protocol (MCP) transport layer using @@ -84,12 +85,14 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String SESSION_ID = "sessionId"; + /** * Default SSE endpoint path as specified by the MCP transport specification. */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String messageEndpoint; @@ -106,6 +109,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** * Flag indicating if the transport is shutting down. */ @@ -113,43 +118,9 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private KeepAliveScheduler keepAliveScheduler; - /** - * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); - } - /** * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization * of messages. * @param baseUrl The base URL for the message endpoint, used to construct the full * endpoint URL for clients. @@ -157,43 +128,25 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param keepAliveInterval The interval for sending keep-alive messages to clients. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param baseUrl The base URL for the message endpoint, used to construct the full - * endpoint URL for clients. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * * @param keepAliveInterval The interval for sending keep-alive messages to - * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -302,41 +255,46 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - // Send initial endpoint event - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - sessions.remove(sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - sessions.remove(sessionId); - }); - - WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); - this.sessions.put(sessionId, session); + return ServerResponse.sse(sseBuilder -> { + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + this.sessions.put(sessionId, session); - try { - sseBuilder.id(sessionId) - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } + try { + sseBuilder.event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + this.sessions.remove(sessionId); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); } /** @@ -355,11 +313,11 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (request.param("sessionId").isEmpty()) { + if (request.param(SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } - String sessionId = request.param("sessionId").get(); + String sessionId = request.param(SESSION_ID).get(); McpServerSession session = sessions.get(sessionId); if (session == null) { @@ -367,11 +325,16 @@ private ServerResponse handleMessage(ServerRequest request) { } try { + final McpTransportContext transportContext = this.contextExtractor.extract(request); + String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); // Process the message through the session's handle method - session.handle(message).block(); // Block for WebMVC compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block + // for + // WebMVC + // compatibility return ServerResponse.ok().build(); } @@ -391,8 +354,6 @@ private ServerResponse handleMessage(ServerRequest request) { */ private class WebMvcMcpSessionTransport implements McpServerTransport { - private final String sessionId; - private final SseBuilder sseBuilder; /** @@ -402,14 +363,11 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { private final ReentrantLock sseBuilderLock = new ReentrantLock(); /** - * Creates a new session transport with the specified ID and SSE builder. - * @param sessionId The unique identifier for this session + * Creates a new session transport with the specified SSE builder. * @param sseBuilder The SSE builder for sending server events to the client */ - WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { - this.sessionId = sessionId; + WebMvcMcpSessionTransport(SseBuilder sseBuilder) { this.sseBuilder = sseBuilder; - logger.debug("Session transport {} initialized with SSE builder", sessionId); } /** @@ -422,12 +380,11 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { sseBuilderLock.lock(); try { - String jsonText = objectMapper.writeValueAsString(message); - sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); - logger.debug("Message sent to session {}", sessionId); + String jsonText = jsonMapper.writeValueAsString(message); + sseBuilder.event(MESSAGE_EVENT_TYPE).data(jsonText); } catch (Exception e) { - logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + logger.error("Failed to send message: {}", e.getMessage()); sseBuilder.error(e); } finally { @@ -437,15 +394,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured McpJsonMapper. * @param data The source data object to convert * @param typeRef The target type reference - * @return The converted object of type T * @param The target type + * @return The converted object of type T */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -455,14 +412,12 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { @Override public Mono closeGracefully() { return Mono.fromRunnable(() -> { - logger.debug("Closing session transport: {}", sessionId); sseBuilderLock.lock(); try { sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); } catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); } finally { sseBuilderLock.unlock(); @@ -478,10 +433,9 @@ public void close() { sseBuilderLock.lock(); try { sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); } catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); } finally { sseBuilderLock.unlock(); @@ -507,7 +461,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; private String baseUrl = ""; @@ -517,14 +471,17 @@ public static class Builder { private Duration keepAliveInterval; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + /** * Sets the JSON object mapper to use for message serialization/deserialization. - * @param objectMapper The object mapper to use + * @param jsonMapper The object mapper to use * @return This builder instance for method chaining */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -576,18 +533,34 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Builds a new instance of WebMvcSseServerTransportProvider with the configured * settings. * @return A new WebMvcSseServerTransportProvider instance - * @throws IllegalStateException if objectMapper or messageEndpoint is not set + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set */ public WebMvcSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java index fef1920fc..4223084ff 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -4,14 +4,13 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStatelessServerTransport; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +38,7 @@ public class WebMvcStatelessServerTransport implements McpStatelessServerTranspo private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String mcpEndpoint; @@ -51,13 +50,13 @@ public class WebMvcStatelessServerTransport implements McpStatelessServerTranspo private volatile boolean isClosing = false; - private WebMvcStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() @@ -101,7 +100,7 @@ private ServerResponse handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) @@ -111,7 +110,7 @@ private ServerResponse handlePost(ServerRequest request) { try { String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { try { @@ -172,11 +171,12 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Builder() { // used by a static method @@ -185,13 +185,13 @@ private Builder() { /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The ObjectMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -230,10 +230,9 @@ public Builder contextExtractor(McpTransportContextExtractor cont * @throws IllegalStateException if required parameters are not set */ public WebMvcStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebMvcStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + return new WebMvcStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index fa51a0130..f2a58d4d8 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.json.McpJsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -20,11 +21,9 @@ import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse.SseBuilder; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; @@ -83,7 +82,7 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer */ private final boolean disallowDelete; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final RouterFunction routerFunction; @@ -105,7 +104,7 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer /** * Constructs a new WebMvcStreamableServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization * of messages. * @param baseUrl The base URL for the message endpoint, used to construct the full * endpoint URL for clients. @@ -114,14 +113,14 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer * @param disallowDelete Whether to disallow DELETE requests on the endpoint. * @throws IllegalArgumentException if any parameter is null */ - private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; @@ -144,7 +143,8 @@ private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, Strin @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); } @Override @@ -238,9 +238,9 @@ private ServerResponse handleGet(ServerRequest request) { return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); } @@ -263,7 +263,7 @@ private ServerResponse handleGet(ServerRequest request) { sessionId, sseBuilder); // Check if this is a replay request - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); try { @@ -322,17 +322,17 @@ private ServerResponse handlePost(ServerRequest request) { .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); try { String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); // Handle initialization request if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); @@ -354,7 +354,7 @@ private ServerResponse handlePost(ServerRequest request) { } // Handle other messages that require a session - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing")); } @@ -431,9 +431,9 @@ private ServerResponse handleDelete(ServerRequest request) { return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); } @@ -517,7 +517,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId return; } - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); this.sseBuilder.id(messageId != null ? messageId : this.sessionId) .event(MESSAGE_EVENT_TYPE) .data(jsonText); @@ -540,15 +540,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured McpJsonMapper. * @param data The source data object to convert * @param typeRef The target type reference * @return The converted object of type T * @param The target type */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -598,26 +598,27 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; private boolean disallowDelete = false; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Duration keepAliveInterval; /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The McpJsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -678,11 +679,10 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * @throws IllegalStateException if required parameters are not set */ public WebMvcStreamableServerTransportProvider build() { - Assert.notNull(this.objectMapper, "ObjectMapper must be set"); Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); - - return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete, - this.contextExtractor, this.keepAliveInterval); + return new WebMvcStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc9945436 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java @@ -0,0 +1,297 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * servers using Spring WebMVC transport implementations. + * + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different MCP transport mechanisms in a Spring WebMVC environment. It demonstrates how + * contextual information can be passed from client to server through HTTP headers and + * properly extracted and utilized on the server side. + * + *

Transport Types Tested

+ *
    + *
  • Stateless: Tests context propagation with + * {@link WebMvcStatelessServerTransport} where each request is independent
  • + *
  • Streamable HTTP: Tests context propagation with + * {@link WebMvcStreamableServerTransportProvider} supporting stateful server + * sessions
  • + *
  • Server-Sent Events (SSE): Tests context propagation with + * {@link WebMvcSseServerTransportProvider} for long-lived connections
  • + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class McpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private TomcatServer tomcatServer; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private static final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private static final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private static McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + String headerValue = r.servletRequest().getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private static final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(TestStatelessConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void streamableServer() { + + startTomcat(TestStreamableHttpConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void sseServer() { + startTomcat(TestSseConfig.class); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + private void startTomcat(Class componentClass) { + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, componentClass); + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcatServer != null && tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestStatelessConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestStreamableHttpConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport() { + + return WebMvcStreamableServerTransportProvider.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpStreamableServer(WebMvcStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestSseConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport() { + + return WebMvcSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpSseServer(WebMvcSseServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java index 66349216d..36aaa27fb 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -8,6 +8,7 @@ import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.Timeout; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -16,14 +17,12 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import reactor.netty.DisposableServer; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -48,10 +47,7 @@ static class TestConfig { @Bean public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint(MCP_ENDPOINT) - .build(); + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java index cab487f12..2f75551eb 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -16,14 +16,12 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import reactor.netty.DisposableServer; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -48,10 +46,7 @@ static class TestConfig { @Bean public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint(MCP_ENDPOINT) - .build(); + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index bb4c2bf37..ccf3170c9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.apache.catalina.Context; @@ -37,7 +36,10 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index cce36d191..d8d26af48 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -3,7 +3,6 @@ */ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; @@ -92,7 +91,6 @@ static class TestConfig { public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .baseUrl(CUSTOM_CONTEXT_PATH) .messageEndpoint(MESSAGE_ENDPOINT) .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 5d048353c..045f9b3dd 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -6,25 +6,29 @@ import static org.assertj.core.api.Assertions.assertThat; import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; @@ -39,6 +43,13 @@ class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests private WebMvcSseServerTransportProvider mcpServerTransportProvider; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { @@ -58,8 +69,8 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .build(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 7e49ddf3b..66d6d3ae9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; @@ -36,7 +35,7 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java index c7c1e710d..8c7b0a85e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -6,12 +6,15 @@ import static org.assertj.core.api.Assertions.assertThat; import java.time.Duration; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -19,8 +22,6 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.AbstractStatelessIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; @@ -39,6 +40,10 @@ class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests private WebMvcStatelessServerTransport mcpServerTransport; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Configuration @EnableWebMvc static class TestConfig { @@ -46,10 +51,7 @@ static class TestConfig { @Bean public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { - return WebMvcStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); + return WebMvcStatelessServerTransport.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index 16012e7d9..cb7b4a2a0 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -6,25 +6,29 @@ import static org.assertj.core.api.Assertions.assertThat; import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; @@ -39,6 +43,13 @@ class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegratio private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Configuration @EnableWebMvc static class TestConfig { @@ -46,7 +57,7 @@ static class TestConfig { @Bean public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .mcpEndpoint(MESSAGE_ENDPOINT) .build(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java new file mode 100644 index 000000000..1074e8a35 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for WebMvcSseServerTransportProvider + * + * @author lance + */ +class WebMvcSseServerTransportProviderTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_CONTEXT_PATH = ""; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(transport); + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @Test + void validBaseUrl() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + assertThat(client.initialize()).isNotNull(); + } + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return WebMvcSseServerTransportProvider.builder() + .baseUrl("http://localhost:" + PORT + "/") + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .jsonMapper(McpJsonMapper.getDefault()) + .contextExtractor(req -> McpTransportContext.EMPTY) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 563f60de9..7fc22e5d2 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 8e041d91e..270bc4308 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,14 +4,6 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -23,15 +15,14 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; @@ -54,11 +45,25 @@ import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -70,7 +75,7 @@ public abstract class AbstractMcpClientServerIntegrationTests { abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -78,7 +83,6 @@ void simple(String clientType) { var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .requestTimeout(Duration.ofSeconds(1000)) .build(); - try ( // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -88,23 +92,25 @@ void simple(String clientType) { assertThat(client.initialize()).isNotNull(); } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } // --------------------------------------- // Sampling Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); }) .build(); @@ -125,11 +131,13 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { .hasMessage("Client must be configured with sampling capabilities"); } } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -142,13 +150,14 @@ void testCreateMessageSuccess(String clientType) { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -192,11 +201,13 @@ void testCreateMessageSuccess(String clientType) { assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); }); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { // Client @@ -216,20 +227,16 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -253,30 +260,35 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .requestTimeout(Duration.ofSeconds(4)) .tools(tool) .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - - mcpClient.close(); - mcpServer.close(); + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { var clientBuilder = clientBuilders.get(clientType); @@ -294,16 +306,12 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -326,28 +334,34 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.close(); - mcpServer.close(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Elicitation Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) .then(Mono.just(mock(CallToolResult.class)))) .build(); @@ -367,11 +381,13 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { .hasMessage("Client must be configured with elicitation capabilities"); } } - server.closeGracefully().block(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -384,11 +400,12 @@ void testCreateElicitationSuccess(String clientType) { Map.of("message", request.message())); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -422,11 +439,13 @@ void testCreateElicitationSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); } - mcpServer.closeGracefully().block(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -437,18 +456,14 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -468,25 +483,31 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - assertWith(resultRef.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutFail(String clientType) { var latch = new CountDownLatch(1); @@ -508,17 +529,12 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = ElicitRequest.builder() @@ -538,25 +554,31 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - ElicitResult elicitResult = resultRef.get(); - assertThat(elicitResult).isNull(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Roots Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -598,18 +620,19 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { exchange.listRoots(); // try to list roots @@ -636,12 +659,13 @@ void testRootsWithoutCapability(String clientType) { assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); } } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -665,12 +689,13 @@ void testRootsNotificationWithEmptyRootsList(String clientType) { assertThat(rootsRef.get()).isEmpty(); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -698,12 +723,13 @@ void testRootsWithMultipleHandlers(String clientType) { assertThat(rootsRef2.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsServerCloseWithActiveSubscription(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -729,30 +755,26 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { assertThat(rootsRef.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { try { @@ -763,7 +785,7 @@ void testToolCallSuccess(String clientType) { .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); - assertThat(responseBody).isNotBlank(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); } catch (Exception e) { e.printStackTrace(); @@ -786,14 +808,16 @@ void testToolCallSuccess(String clientType) { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); assertThat(response).isNotNull().isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -804,7 +828,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .tool(Tool.builder() .name("tool1") .description("tool1 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { // We trigger a timeout on blocking read, raising an exception @@ -819,25 +843,87 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { assertThat(initResult).isNotNull(); // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. + // the offending tool instead of getting back a timeout. assertThatExceptionOfType(McpError.class) .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) .withMessageContaining("Timeout on blocking read"); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { - mcpServer.close(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { // perform a blocking call to a remote service try { @@ -906,7 +992,7 @@ void testToolListChangeHandlingSuccess(String clientType) { .tool(Tool.builder() .name("tool2") .description("tool2 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> callResponse) .build(); @@ -917,12 +1003,13 @@ void testToolListChangeHandlingSuccess(String clientType) { assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -934,15 +1021,16 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Logging Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testLoggingNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 3; CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); @@ -956,7 +1044,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { .tool(Tool.builder() .name("logging-test") .description("Test logging notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -993,7 +1081,10 @@ void testLoggingNotification(String clientType) throws InterruptedException { .logger("test-logger") .data("Another error message") .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); //@formatter:on }) .build(); @@ -1046,14 +1137,16 @@ void testLoggingNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Progress Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress // token @@ -1068,7 +1161,7 @@ void testProgressNotification(String clientType) throws InterruptedException { .tool(McpSchema.Tool.builder() .name("progress-test") .description("Test progress notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1086,7 +1179,10 @@ void testProgressNotification(String clientType) throws InterruptedException { 0.0, 1.0, "Another processing started"))) .then(exchange.progressNotification( new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); }) .build(); @@ -1151,7 +1247,7 @@ void testProgressNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); } finally { - mcpServer.close(); + mcpServer.closeGracefully().block(); } } @@ -1159,7 +1255,7 @@ void testProgressNotification(String clientType) throws InterruptedException { // Completion Tests // --------------------------------------- @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCompletionShouldReturnExpectedSuggestions(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1181,7 +1277,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -1190,7 +1287,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), new CompleteRequest.CompleteArgument("language", "py")); CompleteResult result = mcpClient.completeCompletion(request); @@ -1199,17 +1296,18 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); } - - mcpServer.close(); } // --------------------------------------- // Ping Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testPingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1221,7 +1319,7 @@ void testPingSuccess(String clientType) { .tool(Tool.builder() .name("ping-async-test") .description("Test ping async behavior") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1238,7 +1336,10 @@ void testPingSuccess(String clientType) { assertThat(result).isNotNull(); }).then(Mono.fromCallable(() -> { executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); })); }) .build(); @@ -1263,15 +1364,16 @@ void testPingSuccess(String clientType) { // Verify execution order assertThat(executionOrder.get()).isEqualTo("123"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1323,7 +1425,7 @@ void testStructuredOutputValidationSuccess(String clientType) { // In WebMVC, structured content is returned properly if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) .containsEntry("operation", "2 + 3") .containsEntry("timestamp", "2024-01-01T10:00:00Z"); } @@ -1339,12 +1441,129 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } - mcpServer.close(); + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @ValueSource(strings = { "httpclient" }) void testStructuredOutputValidationFailure(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1393,12 +1612,13 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1443,12 +1663,13 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1520,8 +1741,9 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } private double evaluateExpression(String expression) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java index 618247d61..240732ebe 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -4,12 +4,6 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; - import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -20,9 +14,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; @@ -33,10 +24,21 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + public abstract class AbstractStatelessIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -48,7 +50,7 @@ public abstract class AbstractStatelessIntegrationTests { abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -66,23 +68,16 @@ void simple(String clientType) { assertThat(client.initialize()).isNotNull(); } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } // --------------------------------------- // Tools Tests // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -90,7 +85,7 @@ void testToolCallSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification .builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((ctx, request) -> { try { @@ -126,12 +121,13 @@ void testToolCallSuccess(String clientType) { assertThat(response).isNotNull().isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -142,7 +138,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .tool(Tool.builder() .name("tool1") .description("tool1 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((context, request) -> { // We trigger a timeout on blocking read, raising an exception @@ -163,12 +159,13 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) .withMessageContaining("Timeout on blocking read"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -176,7 +173,7 @@ void testToolListChangeHandlingSuccess(String clientType) { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification .builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((ctx, request) -> { // perform a blocking call to a remote service try { @@ -237,19 +234,20 @@ void testToolListChangeHandlingSuccess(String clientType) { .tool(Tool.builder() .name("tool2") .description("tool2 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> callResponse) .build(); mcpServer.addTool(tool2); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -261,16 +259,16 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -323,7 +321,7 @@ void testStructuredOutputValidationSuccess(String clientType) { // In WebMVC, structured content is returned properly if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) .containsEntry("operation", "2 + 3") .containsEntry("timestamp", "2024-01-01T10:00:00Z"); } @@ -339,12 +337,131 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on - mcpServer.close(); + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that throws an exception to simulate an error + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationFailure(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -394,12 +511,13 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -444,12 +562,13 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -521,8 +640,9 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } private double evaluateExpression(String expression) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index 5484a63c2..cd8458311 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -9,8 +9,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -93,8 +93,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index ed34ebff6..d0b1c46a2 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -10,6 +10,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +49,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v3") .withCommand("node dist/index.js streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) @@ -134,10 +135,13 @@ McpAsyncClient client(McpClientTransport transport, Function client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -217,9 +221,10 @@ void testSessionClose() { // In case of Streamable HTTP this call should issue a HTTP DELETE request // invalidating the session StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the - // request without any broken connections. - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index ea3739da5..e1b051204 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -22,8 +22,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -54,6 +52,8 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Test suite for the {@link McpAsyncClient} that can be used with different * {@link McpTransport} implementations. @@ -67,12 +67,6 @@ public abstract class AbstractMcpAsyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -117,16 +111,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, String action) { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -192,7 +176,12 @@ void testListAllToolsReturnsImmutableList() { .consumeNextWith(result -> { assertThat(result.tools()).isNotNull(); // Verify that the returned list is immutable - assertThatThrownBy(() -> result.tools().add(new Tool("test", "test", "{\"type\":\"object\"}"))) + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) .isInstanceOf(UnsupportedOperationException.class); }) .verifyComplete(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 175a0107c..21e0c1492 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -22,8 +22,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -66,12 +64,6 @@ public abstract class AbstractMcpSyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -114,17 +106,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { @@ -554,11 +535,13 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); withClient(createMcpTransport(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), client -> { assertThatCode(() -> { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 1e87d4420..d6677ec9a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -7,7 +7,6 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -26,6 +25,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -89,30 +89,27 @@ void testGracefulShutdown() { void testImmediateClose() { var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test @Deprecated void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -120,14 +117,20 @@ void testAddTool() { @Test void testAddToolCall() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -136,68 +139,88 @@ void testAddToolCall() { @Test @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build())).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ); @@ -210,17 +233,23 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -229,11 +258,17 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -247,20 +282,23 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -299,8 +337,13 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); @@ -317,7 +360,7 @@ void testAddResourceWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); }); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -328,14 +371,19 @@ void testAddResourceWithoutCapability() { // Create a server without resource capabilities McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } @@ -345,11 +393,191 @@ void testRemoveResourceWithoutCapability() { McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + // --------------------------------------- // Prompts Tests // --------------------------------------- @@ -371,7 +599,8 @@ void testAddPromptWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); }); } @@ -386,7 +615,7 @@ void testAddPromptWithoutCapability() { .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -397,7 +626,7 @@ void testRemovePromptWithoutCapability() { McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -427,10 +656,7 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 5d70ae4c0..0a59d0aae 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -4,17 +4,8 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.util.List; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -25,6 +16,14 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpSyncServer} that can be used with different @@ -77,14 +76,14 @@ void testConstructorWithInvalidArguments() { void testGracefulShutdown() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testImmediateClose() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); } @Test @@ -93,21 +92,13 @@ void testGetAsyncServer() { assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test @Deprecated void testAddTool() { @@ -115,12 +106,16 @@ void testAddTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -129,75 +124,98 @@ void testAddToolCall() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build())).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); List specs = List.of( McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ); @@ -210,17 +228,22 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -229,16 +252,20 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -247,19 +274,18 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -270,9 +296,9 @@ void testNotifyToolsListChanged() { void testNotifyResourcesListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -283,7 +309,7 @@ void testNotifyResourcesUpdated() { .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -292,14 +318,19 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -309,31 +340,211 @@ void testAddResourceWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Resource must not be null"); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -344,9 +555,9 @@ void testRemoveResourceWithoutCapability() { void testNotifyPromptsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -356,7 +567,7 @@ void testAddPromptWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Prompt specification must not be null"); } @@ -369,7 +580,8 @@ void testAddPromptWithoutCapability() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -377,7 +589,8 @@ void testAddPromptWithoutCapability() { void testRemovePromptWithoutCapability() { var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -395,7 +608,7 @@ void testRemovePrompt() { assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -404,10 +617,9 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -428,9 +640,8 @@ void testRootsChangeHandlers() { } })) .build(); - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test with multiple consumers @@ -446,7 +657,7 @@ void testRootsChangeHandlers() { .build(); assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test error handling @@ -457,14 +668,14 @@ void testRootsChangeHandlers() { .build(); assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test without consumers var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..723965519 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} \ No newline at end of file diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/pom.xml b/mcp/pom.xml index 1cf61c48f..0e0ed1288 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp jar @@ -20,204 +20,20 @@ git@github.com/modelcontextprotocol/java-sdk.git - - - - biz.aQute.bnd - bnd-maven-plugin - ${bnd-maven-plugin.version} - - - bnd-process - - bnd-process - - - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - ${project.build.outputDirectory}/META-INF/MANIFEST.MF - - - - - - - org.slf4j - slf4j-api - ${slf4j-api.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - io.projectreactor - reactor-core - - - - com.networknt - json-schema-validator - ${json-schema-validator.version} - - - - - jakarta.servlet - jakarta.servlet-api - ${jakarta.servlet.version} - provided - - - - - - org.springframework - spring-webmvc - ${springframework.version} - test - - - - - io.projectreactor.netty - reactor-netty-http - test - - - - - org.springframework - spring-context - ${springframework.version} - test - - - - org.springframework - spring-test - ${springframework.version} - test - - - - org.assertj - assertj-core - ${assert4j.version} - test - - - org.junit.jupiter - junit-jupiter-api - ${junit.version} - test - - - org.junit.jupiter - junit-jupiter-params - ${junit.version} - test - - - org.mockito - mockito-core - ${mockito.version} - test - - - - - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - - - io.projectreactor - reactor-test - test - - - org.testcontainers - junit-jupiter - ${testcontainers.version} - test - - - - org.awaitility - awaitility - ${awaitility.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} - test + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT - - org.apache.tomcat.embed - tomcat-embed-core - ${tomcat.version} - test + io.modelcontextprotocol.sdk + mcp-core + 0.18.0-SNAPSHOT - - org.apache.tomcat.embed - tomcat-embed-websocket - ${tomcat.version} - test - - - - org.testcontainers - toxiproxy - ${toxiproxy.version} - test - - - - \ No newline at end of file + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java deleted file mode 100644 index 72b6e6c1b..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpRequest; -import reactor.util.annotation.Nullable; - -/** - * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or - * Streamable HTTP transport. - * - * @author Daniel Garnier-Moiroux - */ -public interface SyncHttpRequestCustomizer { - - void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body); - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java deleted file mode 100644 index 9e18e189d..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Default implementation for {@link McpTransportContext} which uses a Thread-safe map. - * Objects of this kind are mutable. - * - * @author Dariusz Jędrzejczyk - */ -public class DefaultMcpTransportContext implements McpTransportContext { - - private final Map storage; - - /** - * Create an empty instance. - */ - public DefaultMcpTransportContext() { - this.storage = new ConcurrentHashMap<>(); - } - - DefaultMcpTransportContext(Map storage) { - this.storage = storage; - } - - @Override - public Object get(String key) { - return this.storage.get(key); - } - - @Override - public void put(String key, Object value) { - this.storage.put(key, value); - } - - /** - * Allows copying the contents. - * @return new instance with the copy of the underlying map - */ - public McpTransportContext copy() { - return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java deleted file mode 100644 index 65b80957c..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -/** - * Names of HTTP headers in use by MCP HTTP transports. - * - * @author Dariusz Jędrzejczyk - */ -public interface HttpHeaders { - - /** - * Identifies individual MCP sessions. - */ - String MCP_SESSION_ID = "mcp-session-id"; - - /** - * Identifies events within an SSE Stream. - */ - String LAST_EVENT_ID = "Last-Event-ID"; - - /** - * Identifies the MCP protocol version. - */ - String PROTOCOL_VERSION = "MCP-Protocol-Version"; - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java deleted file mode 100644 index 6172d8637..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ /dev/null @@ -1,61 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ - -package io.modelcontextprotocol.spec; - -import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; -import io.modelcontextprotocol.util.Assert; - -public class McpError extends RuntimeException { - - private JSONRPCError jsonRpcError; - - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.message()); - this.jsonRpcError = jsonRpcError; - } - - @Deprecated - public McpError(Object error) { - super(error.toString()); - } - - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } - - public static Builder builder(int errorCode) { - return new Builder(errorCode); - } - - public static class Builder { - - private final int code; - - private String message; - - private Object data; - - private Builder(int code) { - this.code = code; - } - - public Builder message(String message) { - this.message = message; - return this; - } - - public Builder data(Object data) { - this.data = data; - return this; - } - - public McpError build() { - Assert.hasText(message, "message must not be empty"); - return new McpError(new JSONRPCError(code, message, data)); - } - - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java deleted file mode 100644 index 7f00de60e..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - -@Timeout(15) -public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return HttpClientStreamableHttpTransport.builder(host).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java deleted file mode 100644 index 8646c1b4c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -/** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { - - String host = "http://localhost:3003"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return HttpClientSseClientTransport.builder(host).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - protected void onClose() { - container.stop(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java deleted file mode 100644 index ae33898b7..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ProtocolVersions; - -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; - -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.assertThatCode; - -class McpAsyncClientTests { - - public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", - "1.0.0"); - - public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() - .build(); - - public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( - ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); - - private static final String CONTEXT_KEY = "context.key"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - @Test - void validateContextPassedToTransportConnect() { - McpClientTransport transport = new McpClientTransport() { - Function, Mono> handler; - - final AtomicReference contextValue = new AtomicReference<>(); - - @Override - public Mono connect( - Function, Mono> handler) { - return Mono.deferContextual(ctx -> { - this.handler = handler; - if (ctx.hasKey(CONTEXT_KEY)) { - this.contextValue.set(ctx.get(CONTEXT_KEY)); - } - return Mono.empty(); - }); - } - - @Override - public Mono closeGracefully() { - return Mono.empty(); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (!"hello".equals(this.contextValue.get())) { - return Mono.error(new RuntimeException("Context value not propagated via #connect method")); - } - // We're only interested in handling the init request to provide an init - // response - if (!(message instanceof McpSchema.JSONRPCRequest)) { - return Mono.empty(); - } - McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, - ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); - return handler.apply(Mono.just(initResponse)).then(); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return OBJECT_MAPPER.convertValue(data, typeRef); - } - }; - - assertThatCode(() -> { - McpAsyncClient client = McpClient.async(transport).build(); - client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); - }).doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java deleted file mode 100644 index cc2543aa9..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2025 - 2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; - -import jakarta.servlet.Filter; -import jakarta.servlet.FilterChain; -import jakarta.servlet.ServletException; -import jakarta.servlet.ServletRequest; -import jakarta.servlet.ServletResponse; - -/** - * Simple {@link Filter} which sets a value in a thread local. Used to verify whether MCP - * executions happen on the thread processing the request or are offloaded. - * - * @author Daniel Garnier-Moiroux - */ -public class McpTestServletFilter implements Filter { - - public static final String THREAD_LOCAL_VALUE = McpTestServletFilter.class.getName(); - - private static final ThreadLocal holder = new ThreadLocal<>(); - - @Override - public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) - throws IOException, ServletException { - holder.set(THREAD_LOCAL_VALUE); - try { - filterChain.doFilter(servletRequest, servletResponse); - } - finally { - holder.remove(); - } - } - - public static String getThreadLocalValue() { - return holder.get(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java deleted file mode 100644 index 85dcd26c2..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.time.Duration; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.MockMcpClientTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, - * request-response correlation, and notification processing. - * - * @author Christian Tzolov - */ -class McpClientSessionTests { - - private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); - - private static final Duration TIMEOUT = Duration.ofSeconds(5); - - private static final String TEST_METHOD = "test.method"; - - private static final String TEST_NOTIFICATION = "test.notification"; - - private static final String ECHO_METHOD = "echo"; - - private McpClientSession session; - - private MockMcpClientTransport transport; - - @BeforeEach - void setUp() { - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params)))); - } - - @AfterEach - void tearDown() { - if (session != null) { - session.close(); - } - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The requestTimeout can not be null"); - - assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("transport can not be null"); - } - - TypeReference responseType = new TypeReference<>() { - }; - - @Test - void testSendRequest() { - String testParam = "test parameter"; - String responseData = "test response"; - - // Create a Mono that will emit the response after the request is sent - Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); - // Verify response handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); - }).consumeNextWith(response -> { - // Verify the request was sent - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); - McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; - assertThat(request.method()).isEqualTo(TEST_METHOD); - assertThat(request.params()).isEqualTo(testParam); - assertThat(response).isEqualTo(responseData); - }).verifyComplete(); - } - - @Test - void testSendRequestWithError() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify error handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - // Simulate error response - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); - }).expectError(McpError.class).verify(); - } - - @Test - void testRequestTimeout() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify timeout - StepVerifier.create(responseMono) - .expectError(java.util.concurrent.TimeoutException.class) - .verify(TIMEOUT.plusSeconds(1)); - } - - @Test - void testSendNotification() { - Map params = Map.of("key", "value"); - Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); - - // Verify notification was sent - StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); - McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; - assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); - assertThat(notification.params()).isEqualTo(params); - }).verifyComplete(); - } - - @Test - void testRequestHandling() { - String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, - params -> Mono.just(params)); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); - - // Simulate incoming request - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); - transport.simulateIncomingMessage(request); - - // Verify response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.result()).isEqualTo(echoMessage); - assertThat(response.error()).isNull(); - } - - @Test - void testNotificationHandling() { - Sinks.One receivedParams = Sinks.one(); - - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - - // Simulate incoming notification from the server - Map notificationParams = Map.of("status", "ready"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - TEST_NOTIFICATION, notificationParams); - - transport.simulateIncomingMessage(notification); - - // Verify handler was called - assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); - } - - @Test - void testUnknownMethodHandling() { - // Simulate incoming request for unknown method - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); - transport.simulateIncomingMessage(request); - - // Verify error response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.error()).isNotNull(); - assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); - } - - @Test - void testGracefulShutdown() { - StepVerifier.create(session.closeGracefully()).verifyComplete(); - } - -} diff --git a/pom.xml b/pom.xml index c0b1f7a44..f8bc3a9c2 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 0.18.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk @@ -59,23 +59,23 @@ 17 - 3.26.3 + 3.27.6 5.10.2 - 5.17.0 - 1.20.4 - 1.17.5 + 5.20.0 + 1.21.4 + 1.17.8 1.21.0 2.0.16 1.5.15 - 2.17.0 + 2.19.2 6.2.1 3.11.0 3.1.2 3.5.2 - 3.5.0 + 3.11.2 3.3.0 0.8.10 1.5.0 @@ -96,13 +96,16 @@ 4.2.0 7.1.0 4.1.0 - 1.5.7 + 2.0.0 mcp-bom mcp + mcp-core + mcp-json-jackson2 + mcp-json mcp-spring/mcp-spring-webflux mcp-spring/mcp-spring-webmvc mcp-test @@ -276,6 +279,7 @@ ${maven-javadoc-plugin.version} false + true false none