diff --git a/README.md b/README.md index 436104c63..7bda15006 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # MCP Java SDK +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/license/MIT) [![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) +[![Maven Central](https://img.shields.io/maven-central/v/io.modelcontextprotocol.sdk/mcp.svg?label=Maven%20Central)](https://central.sonatype.com/artifact/io.modelcontextprotocol.sdk/mcp) +[![Java Version](https://img.shields.io/badge/Java-17%2B-orange)](https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.html) + A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. @@ -43,6 +47,7 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - Christian Tzolov - Dariusz Jędrzejczyk +- Daniel Garnier-Moiroux ## Links @@ -50,6 +55,133 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) - [CI/CD](https://github.com/modelcontextprotocol/java-sdk/actions) +## Architecture and Design Decisions + +### Introduction + +Building a general-purpose MCP Java SDK requires making technology decisions in areas where the JDK provides limited or no support. The Java ecosystem is powerful but fragmented: multiple valid approaches exist, each with strong communities. +Our goal is not to prescribe "the one true way," but to provide a reference implementation of the MCP specification that is: + +* **Pragmatic** – makes developers productive quickly +* **Interoperable** – aligns with widely used libraries and practices +* **Pluggable** – allows alternatives where projects prefer different stacks +* **Grounded in team familiarity** – we chose technologies the team can be productive with today, while remaining open to community contributions that broaden the SDK + +### Key Choices and Considerations + +The SDK had to make decisions in the following areas: + +1. **JSON serialization** – mapping between JSON and Java types + +2. **Programming model** – supporting asynchronous processing, cancellation, and streaming while staying simple for blocking use cases + +3. **Observability** – logging and enabling integration with metrics/tracing + +4. **Remote clients and servers** – supporting both consuming MCP servers (client transport) and exposing MCP endpoints (server transport with authorization) + +The following sections explain what we chose, why it made sense, and how the choices align with the SDK's goals. + +### 1. JSON Serialization + +* **SDK Choice**: Jackson for JSON serialization and deserialization, behind an SDK abstraction (`mcp-json`) + +* **Why**: Jackson is widely adopted across the Java ecosystem, provides strong performance and a mature annotation model, and is familiar to the SDK team and many potential contributors. + +* **How we expose it**: Public APIs use a zero-dependency abstraction (`mcp-json`). Jackson is shipped as the default implementation (`mcp-jackson2`), but alternatives can be plugged in. + +* **How it fits the SDK**: This offers a pragmatic default while keeping flexibility for projects that prefer different JSON libraries. + +### 2. Programming Model + +* **SDK Choice**: Reactive Streams for public APIs, with Project Reactor as the internal implementation and a synchronous facade for blocking use cases + +* **Why**: MCP builds on JSON-RPC's asynchronous nature and defines a bidirectional protocol on top of it, enabling asynchronous and streaming interactions. MCP explicitly supports: + + * Multiple in-flight requests and responses + * Notifications that do not expect a reply + * STDIO transports for inter-process communication using pipes + * Streaming transports such as Server-Sent Events and Streamable HTTP + + These requirements call for a programming model more powerful than single-result futures like `CompletableFuture`. + + * **Reactive Streams: the Community Standard** + + Reactive Streams is a small Java specification that standardizes asynchronous stream processing with backpressure. It defines four minimal interfaces (Publisher, Subscriber, Subscription, and Processor). These interfaces are widely recognized as the standard contract for async, non-blocking pipelines in Java. + + * **Reactive Streams Implementation** + + The SDK uses Project Reactor as its implementation of the Reactive Streams specification. Reactor is mature, widely adopted, provides rich operators, and integrates well with observability through context propagation. Team familiarity also allowed us to deliver a solid foundation quickly. + We plan to convert the public API to only expose Reactive Streams interfaces. By defining the public API in terms of Reactive Streams interfaces and using Reactor internally, the SDK stays standards-based while benefiting from a practical, production-ready implementation. + + * **Synchronous Facade in the SDK** + + Not all MCP use cases require streaming pipelines. Many scenarios are as simple as "send a request and block until I get the result." + To support this, the SDK provides a synchronous facade layered on top of the reactive core. Developers can stay in a blocking model when it's enough, while still having access to asynchronous streaming when needed. + +* **How it fits the SDK**: This design balances scalability, approachability, and future evolution such as Virtual Threads and Structured Concurrency in upcoming JDKs. + +### 3. Observability + +* **SDK Choice**: SLF4J for logging; Reactor Context for observability propagation + +* **Why**: SLF4J is the de facto logging facade in Java, with broad compatibility. Reactor Context enables propagation of observability data such as correlation IDs and tracing state across async boundaries. This ensures interoperability with modern observability frameworks. + +* **How we expose it**: Public APIs log through SLF4J only, with no backend included. Observability metadata flows through Reactor pipelines. The SDK itself does not ship metrics or tracing implementations. + +* **How it fits the SDK**: This provides reliable logging by default and seamless integration with Micrometer, OpenTelemetry, or similar systems for metrics and tracing. + +### 4. Remote MCP Clients and Servers + +MCP supports both clients (applications consuming MCP servers) and servers (applications exposing MCP endpoints). The SDK provides support for both sides. + +#### Client Transport in the SDK + +* **SDK Choice**: JDK HttpClient (Java 11+) as the default client, with optional Spring WebClient support + +* **Why**: The JDK HttpClient is built-in, portable, and supports streaming responses. This keeps the default lightweight with no extra dependencies. Spring WebClient support is available for Spring-based projects. + +* **How we expose it**: MCP Client APIs are transport-agnostic. The core module ships with JDK HttpClient transport. A Spring module provides WebClient integration. + +* **How it fits the SDK**: This ensures all applications can talk to MCP servers out of the box, while allowing richer integration in Spring and other environments. + +#### Server Transport in the SDK + +* **SDK Choice**: Jakarta Servlet implementation in core, with optional Spring WebFlux and Spring WebMVC providers + +* **Why**: Servlet is the most widely deployed Java server API. WebFlux and WebMVC cover a significant part of the Spring community. Together these provide reach across blocking and non-blocking models. + +* **How we expose it**: Server APIs are transport-agnostic. Core includes Servlet support. Spring modules extend support for WebFlux and WebMVC. + +* **How it fits the SDK**: This allows developers to expose MCP servers in the most common Java environments today, while enabling other transport implementations such as Netty, Vert.x, or Helidon. + +#### Authorization in the SDK + +* **SDK Choice**: Pluggable authorization hooks for MCP servers; no built-in implementation + +* **Why**: MCP servers must restrict access to authenticated and authorized clients. Authorization needs differ across environments such as Spring Security, MicroProfile JWT, or custom solutions. Providing hooks avoids lock-in and leverages proven libraries. + +* **How we expose it**: Authorization is integrated into the server transport layer. The SDK does not include its own authorization system. + +* **How it fits the SDK**: This keeps server-side security ecosystem-neutral, while ensuring applications can plug in their preferred authorization strategy. + +### Project Structure of the SDK + +The SDK is organized into modules to separate concerns and allow adopters to bring in only what they need: +* `mcp-bom` – Dependency versions +* `mcp-core` – Reference implementation (STDIO, JDK HttpClient, Servlet) +* `mcp-json` – JSON abstraction +* `mcp-jackson2` – Jackson implementation of JSON binding +* `mcp` – Convenience bundle (core + Jackson) +* `mcp-test` – Shared testing utilities +* `mcp-spring` – Spring integrations (WebClient, WebFlux, WebMVC) + +For example, a minimal adopter may depend only on `mcp` (core + Jackson), while a Spring-based application can use `mcp-spring` for deeper framework integration. + +### Future Directions + +The SDK is designed to evolve with the Java ecosystem. Areas we are actively watching include: +Concurrency in the JDK – Virtual Threads and Structured Concurrency may simplify the synchronous API story + ## License This project is licensed under the [MIT License](LICENSE). diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 7214dacda..447c9e0bd 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-bom @@ -27,12 +27,33 @@ + + io.modelcontextprotocol.sdk + mcp-core + ${project.version} + + + io.modelcontextprotocol.sdk mcp ${project.version} + + + io.modelcontextprotocol.sdk + mcp-json + ${project.version} + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + ${project.version} + + io.modelcontextprotocol.sdk diff --git a/mcp-core/pom.xml b/mcp-core/pom.xml new file mode 100644 index 000000000..9e23ffd79 --- /dev/null +++ b/mcp-core/pom.xml @@ -0,0 +1,235 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-core + jar + Java MCP SDK Core + Core classes of the Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + + org.slf4j + slf4j-api + ${slf4j-api.version} + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + + io.projectreactor + reactor-core + + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + provided + + + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + test + + + + org.springframework + spring-webmvc + ${springframework.version} + test + + + + + io.projectreactor.netty + reactor-netty-http + test + + + + + org.springframework + spring-context + ${springframework.version} + test + + + + org.springframework + spring-test + ${springframework.version} + test + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + + + io.projectreactor + reactor-test + test + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + test + + + + org.awaitility + awaitility + ${awaitility.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + + + + + com.google.code.gson + gson + 2.10.1 + test + + + + + diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java new file mode 100644 index 000000000..07d86f40e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -0,0 +1,358 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.ContextView; + +/** + * Handles the protocol initialization phase between client and server + * + *

+ * The initialization phase MUST be the first interaction between client and server. + * During this phase, the client and server perform the following operations: + *

    + *
  • Establish protocol version compatibility
  • + *
  • Exchange and negotiate capabilities
  • + *
  • Share implementation details
  • + *
+ * + * Client Initialization Process + *

+ * The client MUST initiate this phase by sending an initialize request containing: + *

    + *
  • Protocol version supported
  • + *
  • Client capabilities
  • + *
  • Client implementation information
  • + *
+ * + *

+ * After successful initialization, the client MUST send an initialized notification to + * indicate it is ready to begin normal operations. + * + * Server Response + *

+ * The server MUST respond with its own capabilities and information. + * + * Protocol Version Negotiation + *

+ * In the initialize request, the client MUST send a protocol version it supports. This + * SHOULD be the latest version supported by the client. + * + *

+ * If the server supports the requested protocol version, it MUST respond with the same + * version. Otherwise, the server MUST respond with another protocol version it supports. + * This SHOULD be the latest version supported by the server. + * + *

+ * If the client does not support the version in the server's response, it SHOULD + * disconnect. + * + * Request Restrictions + *

+ * Important: The following restrictions apply during initialization: + *

    + *
  • The client SHOULD NOT send requests other than pings before the server has + * responded to the initialize request
  • + *
  • The server SHOULD NOT send requests other than pings and logging before receiving + * the initialized notification
  • + *
+ */ +class LifecycleInitializer { + + private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class); + + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Function sessionSupplier; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private List protocolVersions; + + private final AtomicReference initializationRef = new AtomicReference<>(); + + /** + * The max timeout to await for the client-server connection to be initialized. + */ + private final Duration initializationTimeout; + + /** + * Post-initialization hook to perform additional operations after every successful + * initialization. + */ + private final Function> postInitializationHook; + + public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + List protocolVersions, Duration initializationTimeout, + Function sessionSupplier, + Function> postInitializationHook) { + + Assert.notNull(sessionSupplier, "Session supplier must not be null"); + Assert.notNull(clientCapabilities, "Client capabilities must not be null"); + Assert.notNull(clientInfo, "Client info must not be null"); + Assert.notEmpty(protocolVersions, "Protocol versions must not be empty"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + Assert.notNull(postInitializationHook, "Post-initialization hook must not be null"); + + this.sessionSupplier = sessionSupplier; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions)); + this.initializationTimeout = initializationTimeout; + this.postInitializationHook = postInitializationHook; + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + /** + * Represents the initialization state of the MCP client. + */ + interface Initialization { + + /** + * Returns the MCP client session that is used to communicate with the server. + * This session is established during the initialization process and is used for + * sending requests and notifications. + * @return The MCP client session + */ + McpClientSession mcpSession(); + + /** + * Returns the result of the MCP initialization process. This result contains + * information about the protocol version, capabilities, server info, and + * instructions provided by the server during the initialization phase. + * @return The result of the MCP initialization process + */ + McpSchema.InitializeResult initializeResult(); + + } + + private static class DefaultInitialization implements Initialization { + + /** + * A sink that emits the result of the MCP initialization process. It allows + * subscribers to wait for the initialization to complete. + */ + private final Sinks.One initSink; + + /** + * Holds the result of the MCP initialization process. It is used to cache the + * result for future requests. + */ + private final AtomicReference result; + + /** + * Holds the MCP client session that is used to communicate with the server. It is + * set during the initialization process and used for sending requests and + * notifications. + */ + private final AtomicReference mcpClientSession; + + private DefaultInitialization() { + this.initSink = Sinks.one(); + this.result = new AtomicReference<>(); + this.mcpClientSession = new AtomicReference<>(); + } + + // --------------------------------------------------- + // Public access for mcpSession and initializeResult because they are + // used in by the McpAsyncClient. + // ---------------------------------------------------- + public McpClientSession mcpSession() { + return this.mcpClientSession.get(); + } + + public McpSchema.InitializeResult initializeResult() { + return this.result.get(); + } + + // --------------------------------------------------- + // Private accessors used internally by the LifecycleInitializer to set the MCP + // client session and complete the initialization process. + // --------------------------------------------------- + private void setMcpClientSession(McpClientSession mcpClientSession) { + this.mcpClientSession.set(mcpClientSession); + } + + private Mono await() { + return this.initSink.asMono(); + } + + private void complete(McpSchema.InitializeResult initializeResult) { + // inform all the subscribers waiting for the initialization + this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void cacheResult(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + } + + private void error(Throwable t) { + this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void close() { + this.mcpSession().close(); + } + + private Mono closeGracefully() { + return this.mcpSession().closeGracefully(); + } + + } + + public boolean isInitialized() { + return this.currentInitializationResult() != null; + } + + public McpSchema.InitializeResult currentInitializationResult() { + DefaultInitialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; + } + + /** + * Hook to handle exceptions that occur during the MCP transport session. + *

+ * If the exception is a {@link McpTransportSessionNotFoundException}, it indicates + * that the session was not found, and we should re-initialize the client. + *

+ * @param t The exception to handle + */ + public void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + DefaultInitialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering + // the implicit initialization step. + this.withInitialization("re-initializing", result -> Mono.empty()).subscribe(); + } + } + + /** + * Utility method to ensure the initialization is established before executing an + * operation. + * @param The type of the result Mono + * @param actionName The action to perform when the client is initialized + * @param operation The operation to execute when the client is initialized + * @return A Mono that completes with the result of the operation + */ + public Mono withInitialization(String actionName, Function> operation) { + return Mono.deferContextual(ctx -> { + DefaultInitialization newInit = new DefaultInitialization(); + DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit); + + boolean needsToInitialize = previous == null; + logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + + Mono initializationJob = needsToInitialize + ? this.doInitialize(newInit, this.postInitializationHook, ctx) : previous.await(); + + return initializationJob.map(initializeResult -> this.initializationRef.get()) + .timeout(this.initializationTimeout) + .onErrorResume(ex -> { + this.initializationRef.compareAndSet(newInit, null); + return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); + }) + .flatMap(res -> operation.apply(res) + .contextWrite(c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + res.initializeResult().protocolVersion()))); + }); + } + + private Mono doInitialize(DefaultInitialization initialization, + Function> postInitOperation, ContextView ctx) { + + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); + + McpClientSession mcpClientSession = initialization.mcpSession(); + + String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, + this.clientCapabilities, this.clientInfo); + + Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF); + + return result.flatMap(initializeResult -> { + logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", + initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), + initializeResult.instructions()); + + if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { + return Mono.error(McpError.builder(-32602) + .message("Unsupported protocol version") + .data("Unsupported protocol version from the server: " + initializeResult.protocolVersion()) + .build()); + } + + return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .contextWrite( + c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, initializeResult.protocolVersion())) + .thenReturn(initializeResult); + }).flatMap(initializeResult -> { + initialization.cacheResult(initializeResult); + return postInitOperation.apply(initialization).thenReturn(initializeResult); + }).doOnNext(initialization::complete).onErrorResume(ex -> { + initialization.error(ex); + return Mono.error(ex); + }); + } + + /** + * Closes the current initialization if it exists. + */ + public void close() { + DefaultInitialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + } + + /** + * Gracefully closes the current initialization if it exists. + * @return A Mono that completes when the connection is closed + */ + public Mono closeGracefully() { + return Mono.defer(() -> { + DefaultInitialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose; + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java similarity index 69% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 617cec175..e6a09cd08 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -1,27 +1,27 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import java.time.Duration; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; - +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; @@ -35,13 +35,12 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.util.context.ContextView; /** * The Model Context Protocol (MCP) client implementation that provides asynchronous @@ -76,6 +75,7 @@ * @author Dariusz Jędrzejczyk * @author Christian Tzolov * @author Jihoon Kim + * @author Anurag Pant * @see McpClient * @see McpSchema * @see McpClientSession @@ -85,30 +85,28 @@ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeRef VOID_TYPE_REFERENCE = new TypeRef<>() { }; - public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef PAGINATED_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + public static final TypeRef INITIALIZE_RESULT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + public static final TypeRef LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; - private final AtomicReference initializationRef = new AtomicReference<>(); + public static final TypeRef PROGRESS_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; - /** - * The max timeout to await for the client-server connection to be initialized. - */ - private final Duration initializationTimeout; + public static final String NEGOTIATED_PROTOCOL_VERSION = "io.modelcontextprotocol.client.negotiated-protocol-version"; /** * Client capabilities. @@ -153,15 +151,24 @@ public class McpAsyncClient { private final McpClientTransport transport; /** - * Supported protocol versions. + * The lifecycle initializer that manages the client-server connection initialization. + */ + private final LifecycleInitializer initializer; + + /** + * JSON schema validator to use for validating tool responses against output schemas. + */ + private final JsonSchemaValidator jsonSchemaValidator; + + /** + * Cached tool output schemas. */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private final ConcurrentHashMap> toolsOutputSchemaCache; /** - * The MCP session supplier that manages bidirectional JSON-RPC communication between - * clients and servers. + * Whether to enable automatic schema caching during callTool operations. */ - private final Function sessionSupplier; + private final boolean enableCallToolSchemaCaching; /** * Create a new McpAsyncClient with the given transport and session request-response @@ -169,10 +176,12 @@ public class McpAsyncClient { * @param transport the transport to use. * @param requestTimeout the session request-response timeout. * @param initializationTimeout the max timeout to await for the client-server - * @param features the MCP Client supported features. + * @param jsonSchemaValidator the JSON schema validator to use for validating tool + * @param features the MCP Client supported features. responses against output + * schemas. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -182,11 +191,19 @@ public class McpAsyncClient { this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); - this.initializationTimeout = initializationTimeout; + this.jsonSchemaValidator = jsonSchemaValidator; + this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); + this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching(); // Request Handlers Map> requestHandlers = new HashMap<>(); + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, params -> { + logger.debug("Received ping: {}", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + return Mono.just(Map.of()); + }); + // Roots List Request Handler if (this.clientCapabilities.roots() != null) { requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler()); @@ -195,7 +212,8 @@ public class McpAsyncClient { // Sampling Handler if (this.clientCapabilities.sampling() != null) { if (features.samplingHandler() == null) { - throw new McpError("Sampling handler must not be null when client capabilities include sampling"); + throw new IllegalArgumentException( + "Sampling handler must not be null when client capabilities include sampling"); } this.samplingHandler = features.samplingHandler(); requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); @@ -204,7 +222,8 @@ public class McpAsyncClient { // Elicitation Handler if (this.clientCapabilities.elicitation() != null) { if (features.elicitationHandler() == null) { - throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + throw new IllegalArgumentException( + "Elicitation handler must not be null when client capabilities include elicitation"); } this.elicitationHandler = features.elicitationHandler(); requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); @@ -267,28 +286,49 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.transport.setExceptionHandler(this::handleException); - this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx)); - } + // Utility Progress Notification + List>> progressConsumersFinal = new ArrayList<>(); + progressConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification))); + if (!Utils.isEmpty(features.progressConsumers())) { + progressConsumersFinal.addAll(features.progressConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, + asyncProgressNotificationHandler(progressConsumersFinal)); + + Function> postInitializationHook = init -> { - private void handleException(Throwable t) { - logger.warn("Handling exception", t); - if (t instanceof McpTransportSessionNotFoundException) { - Initialization previous = this.initializationRef.getAndSet(null); - if (previous != null) { - previous.close(); + if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { + return Mono.empty(); } - // Providing an empty operation since we are only interested in triggering the - // implicit initialization step. - withSession("re-initializing", result -> Mono.empty()).subscribe(); - } + + return this.listToolsInternal(init, McpSchema.FIRST_PAGE).doOnNext(listToolsResult -> { + listToolsResult.tools() + .forEach(tool -> logger.debug("Tool {} schema: {}", tool.name(), tool.outputSchema())); + if (enableCallToolSchemaCaching && listToolsResult.tools() != null) { + // Cache tools output schema + listToolsResult.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }).then(); + }; + + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), + initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers, con -> con.contextWrite(ctx)), + postInitializationHook); + + this.transport.setExceptionHandler(this.initializer::handleException); } - private McpSchema.InitializeResult currentInitializationResult() { - Initialization current = this.initializationRef.get(); - McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; - return initializeResult; + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.initializer.currentInitializationResult(); } /** @@ -296,7 +336,7 @@ private McpSchema.InitializeResult currentInitializationResult() { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.capabilities() : null; } @@ -306,7 +346,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.instructions() : null; } @@ -315,7 +355,7 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.serverInfo() : null; } @@ -324,8 +364,7 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - Initialization current = this.initializationRef.get(); - return current != null && (current.result.get() != null); + return this.initializer.isInitialized(); } /** @@ -348,10 +387,7 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - Initialization current = this.initializationRef.getAndSet(null); - if (current != null) { - current.close(); - } + this.initializer.close(); this.transport.close(); } @@ -361,15 +397,14 @@ public void close() { */ public Mono closeGracefully() { return Mono.defer(() -> { - Initialization current = this.initializationRef.getAndSet(null); - Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); - return sessionClose.then(transport.closeGracefully()); + return this.initializer.closeGracefully().then(transport.closeGracefully()); }); } // -------------------------- // Initialization // -------------------------- + /** * The initialization phase should be the first interaction between client and server. * The client will ensure it happens in case it has not been explicitly called and in @@ -397,118 +432,7 @@ public Mono closeGracefully() { *

*/ public Mono initialize() { - return withSession("by explicit API call", init -> Mono.just(init.get())); - } - - private Mono doInitialize(Initialization initialization, ContextView ctx) { - initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); - - McpClientSession mcpClientSession = initialization.mcpSession(); - - String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off - latestVersion, - this.clientCapabilities, - this.clientInfo); // @formatter:on - - Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, INITIALIZE_RESULT_TYPE_REF); - - return result.flatMap(initializeResult -> { - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", - initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), - initializeResult.instructions()); - - if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); - } - - return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) - .thenReturn(initializeResult); - }).doOnNext(initialization::complete).onErrorResume(ex -> { - initialization.error(ex); - return Mono.error(ex); - }); - } - - private static class Initialization { - - private final Sinks.One initSink = Sinks.one(); - - private final AtomicReference result = new AtomicReference<>(); - - private final AtomicReference mcpClientSession = new AtomicReference<>(); - - static Initialization create() { - return new Initialization(); - } - - void setMcpClientSession(McpClientSession mcpClientSession) { - this.mcpClientSession.set(mcpClientSession); - } - - McpClientSession mcpSession() { - return this.mcpClientSession.get(); - } - - McpSchema.InitializeResult get() { - return this.result.get(); - } - - Mono await() { - return this.initSink.asMono(); - } - - void complete(McpSchema.InitializeResult initializeResult) { - // first ensure the result is cached - this.result.set(initializeResult); - // inform all the subscribers waiting for the initialization - this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); - } - - void error(Throwable t) { - this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); - } - - void close() { - this.mcpSession().close(); - } - - Mono closeGracefully() { - return this.mcpSession().closeGracefully(); - } - - } - - /** - * Utility method to handle the common pattern of ensuring initialization before - * executing an operation. - * @param The type of the result Mono - * @param actionName The action to perform when the client is initialized - * @param operation The operation to execute when the client is initialized - * @return A Mono that completes with the result of the operation - */ - private Mono withSession(String actionName, Function> operation) { - return Mono.deferContextual(ctx -> { - Initialization newInit = Initialization.create(); - Initialization previous = this.initializationRef.compareAndExchange(null, newInit); - - boolean needsToInitialize = previous == null; - logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); - - Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx) - : previous.await(); - - return initializationJob.map(initializeResult -> this.initializationRef.get()) - .timeout(this.initializationTimeout) - .onErrorResume(ex -> { - logger.warn("Failed to initialize", ex); - return Mono.error(new McpError("Client failed to initialize " + actionName)); - }) - .flatMap(operation); - }); + return this.initializer.withInitialization("by explicit API call", init -> Mono.just(init.initializeResult())); } // -------------------------- @@ -520,13 +444,14 @@ private Mono withSession(String actionName, Function ping() { - return this.withSession("pinging the server", + return this.initializer.withInitialization("pinging the server", init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- // Roots // -------------------------- + /** * Adds a new root to the client's root list. * @param root The root to add. @@ -535,15 +460,15 @@ public Mono ping() { public Mono addRoot(Root root) { if (root == null) { - return Mono.error(new McpError("Root must not be null")); + return Mono.error(new IllegalArgumentException("Root must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } if (this.roots.containsKey(root.uri())) { - return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists")); + return Mono.error(new IllegalStateException("Root with uri '" + root.uri() + "' already exists")); } this.roots.put(root.uri(), root); @@ -569,11 +494,11 @@ public Mono addRoot(Root root) { public Mono removeRoot(String rootUri) { if (rootUri == null) { - return Mono.error(new McpError("Root uri must not be null")); + return Mono.error(new IllegalArgumentException("Root uri must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } Root removed = this.roots.remove(rootUri); @@ -591,7 +516,7 @@ public Mono removeRoot(String rootUri) { } return Mono.empty(); } - return Mono.error(new McpError("Root with uri '" + rootUri + "' not found")); + return Mono.error(new IllegalStateException("Root with uri '" + rootUri + "' not found")); } /** @@ -601,7 +526,7 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withSession("sending roots list changed notification", + return this.initializer.withInitialization("sending roots list changed notification", init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } @@ -632,7 +557,7 @@ private RequestHandler samplingCreateMessageHandler() { // -------------------------- private RequestHandler elicitationCreateHandler() { return params -> { - ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + ElicitRequest request = transport.unmarshalFrom(params, new TypeRef<>() { }); return this.elicitationHandler.apply(request); @@ -642,10 +567,10 @@ private RequestHandler elicitationCreateHandler() { // -------------------------- // Tools // -------------------------- - private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CALL_TOOL_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_TOOLS_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -660,29 +585,57 @@ private RequestHandler elicitationCreateHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withSession("calling tools", init -> { - if (init.get().capabilities().tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); + return this.initializer.withInitialization("calling tool", init -> { + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); } + return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF) + .flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result))); }); } + private McpSchema.CallToolResult validateToolResult(String toolName, McpSchema.CallToolResult result) { + + if (!this.enableCallToolSchemaCaching || result == null || result.isError() == Boolean.TRUE) { + // if tool schema caching is disabled or tool call resulted in an error - skip + // validation and return the result as it is + return result; + } + + Map optOutputSchema = this.toolsOutputSchemaCache.get(toolName); + + if (optOutputSchema == null) { + logger.warn( + "Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}", + result.structuredContent()); + return result; + } + + // Validate the tool output against the cached output schema + var validation = this.jsonSchemaValidator.validate(optOutputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + throw new IllegalArgumentException("Tool call result validation failed: " + validation.errorMessage()); + } + + return result; + } + /** * Retrieves the list of all tools provided by the server. * @return A Mono that emits the list of all tools result */ public Mono listTools() { return this.listTools(McpSchema.FIRST_PAGE).expand(result -> { - if (result.nextCursor() != null) { - return this.listTools(result.nextCursor()); - } - return Mono.empty(); + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTools(next) : Mono.empty(); }).reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { allToolsResult.tools().addAll(result.tools()); return allToolsResult; - }); + }).map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); } /** @@ -691,14 +644,26 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withSession("listing tools", init -> { - if (init.get().capabilities().tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); - } - return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); - }); + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor)); + } + + private Mono listToolsInternal(Initialization init, String cursor) { + + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF) + .doOnNext(result -> { + if (this.enableCallToolSchemaCaching && result.tools() != null) { + // Cache tools output schema + result.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }); } private NotificationHandler asyncToolsChangeNotificationHandler( @@ -718,13 +683,13 @@ private NotificationHandler asyncToolsChangeNotificationHandler( // Resources // -------------------------- - private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCES_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef READ_RESOURCE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -736,15 +701,13 @@ private NotificationHandler asyncToolsChangeNotificationHandler( * @see #readResource(McpSchema.Resource) */ public Mono listResources() { - return this.listResources(McpSchema.FIRST_PAGE).expand(result -> { - if (result.nextCursor() != null) { - return this.listResources(result.nextCursor()); - } - return Mono.empty(); - }).reduce(new McpSchema.ListResourcesResult(new ArrayList<>(), null), (allResourcesResult, result) -> { - allResourcesResult.resources().addAll(result.resources()); - return allResourcesResult; - }); + return this.listResources(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listResources(result.nextCursor()) : Mono.empty()) + .reduce(new McpSchema.ListResourcesResult(new ArrayList<>(), null), (allResourcesResult, result) -> { + allResourcesResult.resources().addAll(result.resources()); + return allResourcesResult; + }) + .map(result -> new McpSchema.ListResourcesResult(Collections.unmodifiableList(result.resources()), null)); } /** @@ -757,9 +720,9 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withSession("listing resources", init -> { - if (init.get().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("listing resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), @@ -789,9 +752,9 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withSession("reading resources", init -> { - if (init.get().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("reading resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); @@ -806,17 +769,16 @@ public Mono readResource(McpSchema.ReadResourceReq * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates() { - return this.listResourceTemplates(McpSchema.FIRST_PAGE).expand(result -> { - if (result.nextCursor() != null) { - return this.listResourceTemplates(result.nextCursor()); - } - return Mono.empty(); - }) + return this.listResourceTemplates(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listResourceTemplates(result.nextCursor()) + : Mono.empty()) .reduce(new McpSchema.ListResourceTemplatesResult(new ArrayList<>(), null), (allResourceTemplatesResult, result) -> { allResourceTemplatesResult.resourceTemplates().addAll(result.resourceTemplates()); return allResourceTemplatesResult; - }); + }) + .map(result -> new McpSchema.ListResourceTemplatesResult( + Collections.unmodifiableList(result.resourceTemplates()), null)); } /** @@ -828,9 +790,9 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withSession("listing resource templates", init -> { - if (init.get().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return this.initializer.withInitialization("listing resource templates", init -> { + if (init.initializeResult().capabilities().resources() == null) { + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), @@ -848,7 +810,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withSession("subscribing to resources", init -> init.mcpSession() + return this.initializer.withInitialization("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -862,7 +824,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withSession("unsubscribing from resources", init -> init.mcpSession() + return this.initializer.withInitialization("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -881,7 +843,7 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( List, Mono>> resourcesUpdateConsumers) { return params -> { McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification = transport.unmarshalFrom(params, - new TypeReference<>() { + new TypeRef<>() { }); return readResource(new McpSchema.ReadResourceRequest(resourcesUpdatedNotification.uri())) @@ -898,10 +860,10 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( // -------------------------- // Prompts // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_PROMPTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef GET_PROMPT_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -911,15 +873,13 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts() { - return this.listPrompts(McpSchema.FIRST_PAGE).expand(result -> { - if (result.nextCursor() != null) { - return this.listPrompts(result.nextCursor()); - } - return Mono.empty(); - }).reduce(new ListPromptsResult(new ArrayList<>(), null), (allPromptsResult, result) -> { - allPromptsResult.prompts().addAll(result.prompts()); - return allPromptsResult; - }); + return this.listPrompts(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? this.listPrompts(result.nextCursor()) : Mono.empty()) + .reduce(new ListPromptsResult(new ArrayList<>(), null), (allPromptsResult, result) -> { + allPromptsResult.prompts().addAll(result.prompts()); + return allPromptsResult; + }) + .map(result -> new McpSchema.ListPromptsResult(Collections.unmodifiableList(result.prompts()), null)); } /** @@ -930,7 +890,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withSession("listing prompts", init -> init.mcpSession() + return this.initializer.withInitialization("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -944,7 +904,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withSession("getting prompts", init -> init.mcpSession() + return this.initializer.withInitialization("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -962,14 +922,6 @@ private NotificationHandler asyncPromptsChangeNotificationHandler( // -------------------------- // Logging // -------------------------- - /** - * Create a notification handler for logging notifications from the server. This - * handler automatically distributes logging messages to all registered consumers. - * @param loggingConsumers List of consumers that will be notified when a logging - * message is received. Each consumer receives the logging message notification. - * @return A NotificationHandler that processes log notifications by distributing the - * message to all registered consumers - */ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { @@ -992,28 +944,44 @@ private NotificationHandler asyncLoggingNotificationHandler( */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { if (loggingLevel == null) { - return Mono.error(new McpError("Logging level must not be null")); + return Mono.error(new IllegalArgumentException("Logging level must not be null")); } - return this.withSession("setting logging level", init -> { + return this.initializer.withInitialization("setting logging level", init -> { + if (init.initializeResult().capabilities().logging() == null) { + return Mono.error(new IllegalStateException("Server's Logging capabilities are not enabled!")); + } var params = new McpSchema.SetLevelRequest(loggingLevel); return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); } + private NotificationHandler asyncProgressNotificationHandler( + List>> progressConsumers) { + + return params -> { + McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params, + PROGRESS_NOTIFICATION_TYPE_REF); + + return Flux.fromIterable(progressConsumers) + .flatMap(consumer -> consumer.apply(progressNotification)) + .then(); + }; + } + /** * This method is package-private and used for test only. Should not be called by user * code. * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; + this.initializer.setProtocolVersions(protocolVersions); } // -------------------------- // Completions // -------------------------- - private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -1027,7 +995,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withSession("complete completions", init -> init.mcpSession() + return this.initializer.withInitialization("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java similarity index 77% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index d8925b005..c9989f832 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -4,17 +4,10 @@ package io.modelcontextprotocol.client; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -22,9 +15,19 @@ import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -72,6 +75,7 @@ * .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> System.out.println("Resources updated: " + resources))) * .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> System.out.println("Prompts updated: " + prompts))) * .loggingConsumer(message -> Mono.fromRunnable(() -> System.out.println("Log message: " + message))) + * .resourcesUpdateConsumer(resourceContents -> Mono.fromRunnable(() -> System.out.println("Resources contents updated: " + resourceContents))) * .build(); * } * @@ -97,6 +101,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Anurag Pant * @see McpAsyncClient * @see McpSyncClient * @see McpTransport @@ -163,7 +168,7 @@ class SyncSpec { private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -177,10 +182,18 @@ class SyncSpec { private final List> loggingConsumers = new ArrayList<>(); + private final List> progressConsumers = new ArrayList<>(); + private Function samplingHandler; private Function elicitationHandler; + private Supplier contextProvider = () -> McpTransportContext.EMPTY; + + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -334,6 +347,22 @@ public SyncSpec resourcesChangeConsumer(Consumer> resou return this; } + /** + * Adds a consumer to be notified when a specific resource is updated. This allows + * the client to react to changes in individual resources, such as updates to + * their content or metadata. + * @param resourcesUpdateConsumer A consumer function that processes the updated + * resource and returns a Mono indicating the completion of the processing. Must + * not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException If the resourcesUpdateConsumer is null. + */ + public SyncSpec resourcesUpdateConsumer(Consumer> resourcesUpdateConsumer) { + Assert.notNull(resourcesUpdateConsumer, "Resources update consumer must not be null"); + this.resourcesUpdateConsumers.add(resourcesUpdateConsumer); + return this; + } + /** * Adds a consumer to be notified when the available prompts change. This allows * the client to react to changes in the server's prompt templates, such as new @@ -377,6 +406,78 @@ public SyncSpec loggingConsumers(List progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public SyncSpec progressConsumers(List> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + + /** + * Add a provider of {@link McpTransportContext}, providing a context before + * calling any client operation. This allows to extract thread-locals and hand + * them over to the underlying transport. + *

+ * There is no direct equivalent in {@link AsyncSpec}. To achieve the same result, + * append {@code contextWrite(McpTransportContext.KEY, context)} to any + * {@link McpAsyncClient} call. + * @param contextProvider A supplier to create a context + * @return This builder for method chaining + */ + public SyncSpec transportContextProvider(Supplier contextProvider) { + this.contextProvider = contextProvider; + return this; + } + + /** + * Add a {@link JsonSchemaValidator} to validate the JSON structure of the + * structured output. + * @param jsonSchemaValidator A validator to validate the JSON structure of the + * structured output. Must not be null. + * @return This builder for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -385,12 +486,14 @@ public SyncSpec loggingConsumers(List roots = new HashMap<>(); @@ -435,10 +538,16 @@ class AsyncSpec { private final List>> loggingConsumers = new ArrayList<>(); + private final List>> progressConsumers = new ArrayList<>(); + private Function> samplingHandler; private Function> elicitationHandler; + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -654,17 +763,76 @@ public AsyncSpec loggingConsumers( return this; } + /** + * Adds a consumer to be notified of progress notifications from the server. This + * allows the client to track long-running operations and provide feedback to + * users. + * @param progressConsumer A consumer that receives progress notifications. Must + * not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public AsyncSpec progressConsumer(Function> progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public AsyncSpec progressConsumers( + List>> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool responses against + * output schemas. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, - this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, - this.elicitationHandler)); + this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, + this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java similarity index 75% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index bd1a0985a..127d53337 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -59,8 +59,10 @@ class McpClientFeatures { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, @@ -68,8 +70,10 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -79,8 +83,10 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -89,8 +95,10 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -106,8 +114,27 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. + */ + public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, + Map roots, + List, Mono>> toolsChangeConsumers, + List, Mono>> resourcesChangeConsumers, + List, Mono>> resourcesUpdateConsumers, + List, Mono>> promptsChangeConsumers, + List>> loggingConsumers, + Function> samplingHandler, + Function> elicitationHandler) { + this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, + elicitationHandler, false); } /** @@ -149,6 +176,12 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } + List>> progressConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec.progressConsumers()) { + progressConsumers.add(l -> Mono.fromRunnable(() -> consumer.accept(l)) + .subscribeOn(Schedulers.boundedElastic())); + } + Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); @@ -159,7 +192,8 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, samplingHandler, elicitationHandler); + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, + syncSpec.enableCallToolSchemaCaching); } } @@ -174,8 +208,10 @@ public static Async fromSync(Sync syncSpec) { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -183,8 +219,10 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -196,8 +234,10 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param resourcesUpdateConsumers the resource update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -205,8 +245,10 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -222,8 +264,26 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. + */ + public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, + Map roots, List>> toolsChangeConsumers, + List>> resourcesChangeConsumers, + List>> resourcesUpdateConsumers, + List>> promptsChangeConsumers, + List> loggingConsumers, + Function samplingHandler, + Function elicitationHandler) { + this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, + elicitationHandler, false); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 27b020f05..7fdaa8941 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,17 +5,19 @@ package io.modelcontextprotocol.client; import java.time.Duration; -import java.util.List; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; /** * A synchronous client implementation for the Model Context Protocol (MCP) that wraps an @@ -64,14 +66,28 @@ public class McpSyncClient implements AutoCloseable { private final McpAsyncClient delegate; + private final Supplier contextProvider; + /** * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. + * @param contextProvider the supplier of context before calling any non-blocking + * operation on underlying delegate */ - McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate, Supplier contextProvider) { Assert.notNull(delegate, "The delegate can not be null"); + Assert.notNull(contextProvider, "The contextProvider can not be null"); this.delegate = delegate; + this.contextProvider = contextProvider; + } + + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.delegate.getCurrentInitializationResult(); } /** @@ -170,14 +186,14 @@ public boolean closeGracefully() { public McpSchema.InitializeResult initialize() { // TODO: block takes no argument here as we assume the async client is // configured with a requestTimeout at all times - return this.delegate.initialize().block(); + return withProvidedContext(this.delegate.initialize()).block(); } /** * Send a roots/list_changed notification. */ public void rootsListChangedNotification() { - this.delegate.rootsListChangedNotification().block(); + withProvidedContext(this.delegate.rootsListChangedNotification()).block(); } /** @@ -199,7 +215,7 @@ public void removeRoot(String rootUri) { * @return */ public Object ping() { - return this.delegate.ping().block(); + return withProvidedContext(this.delegate.ping()).block(); } // -------------------------- @@ -217,7 +233,8 @@ public Object ping() { * Boolean indicating if the execution failed (true) or succeeded (false/absent) */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { - return this.delegate.callTool(callToolRequest).block(); + return withProvidedContext(this.delegate.callTool(callToolRequest)).block(); + } /** @@ -227,7 +244,7 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque * pagination if more tools are available */ public McpSchema.ListToolsResult listTools() { - return this.delegate.listTools().block(); + return withProvidedContext(this.delegate.listTools()).block(); } /** @@ -238,7 +255,8 @@ public McpSchema.ListToolsResult listTools() { * pagination if more tools are available */ public McpSchema.ListToolsResult listTools(String cursor) { - return this.delegate.listTools(cursor).block(); + return withProvidedContext(this.delegate.listTools(cursor)).block(); + } // -------------------------- @@ -250,7 +268,8 @@ public McpSchema.ListToolsResult listTools(String cursor) { * @return The list of all resources result */ public McpSchema.ListResourcesResult listResources() { - return this.delegate.listResources().block(); + return withProvidedContext(this.delegate.listResources()).block(); + } /** @@ -259,7 +278,8 @@ public McpSchema.ListResourcesResult listResources() { * @return The list of resources result */ public McpSchema.ListResourcesResult listResources(String cursor) { - return this.delegate.listResources(cursor).block(); + return withProvidedContext(this.delegate.listResources(cursor)).block(); + } /** @@ -268,7 +288,8 @@ public McpSchema.ListResourcesResult listResources(String cursor) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { - return this.delegate.readResource(resource).block(); + return withProvidedContext(this.delegate.readResource(resource)).block(); + } /** @@ -277,7 +298,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.delegate.readResource(readResourceRequest).block(); + return withProvidedContext(this.delegate.readResource(readResourceRequest)).block(); + } /** @@ -285,7 +307,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r * @return The list of all resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { - return this.delegate.listResourceTemplates().block(); + return withProvidedContext(this.delegate.listResourceTemplates()).block(); + } /** @@ -297,7 +320,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { * @return The list of resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor) { - return this.delegate.listResourceTemplates(cursor).block(); + return withProvidedContext(this.delegate.listResourceTemplates(cursor)).block(); + } /** @@ -310,7 +334,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor * subscribe to. */ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - this.delegate.subscribeResource(subscribeRequest).block(); + withProvidedContext(this.delegate.subscribeResource(subscribeRequest)).block(); + } /** @@ -319,7 +344,8 @@ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { * to unsubscribe from. */ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - this.delegate.unsubscribeResource(unsubscribeRequest).block(); + withProvidedContext(this.delegate.unsubscribeResource(unsubscribeRequest)).block(); + } // -------------------------- @@ -331,7 +357,7 @@ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) * @return The list of all prompts result. */ public ListPromptsResult listPrompts() { - return this.delegate.listPrompts().block(); + return withProvidedContext(this.delegate.listPrompts()).block(); } /** @@ -340,11 +366,12 @@ public ListPromptsResult listPrompts() { * @return The list of prompts result. */ public ListPromptsResult listPrompts(String cursor) { - return this.delegate.listPrompts(cursor).block(); + return withProvidedContext(this.delegate.listPrompts(cursor)).block(); + } public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { - return this.delegate.getPrompt(getPromptRequest).block(); + return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block(); } /** @@ -352,7 +379,8 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { * @param loggingLevel the min logging level */ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { - this.delegate.setLoggingLevel(loggingLevel).block(); + withProvidedContext(this.delegate.setLoggingLevel(loggingLevel)).block(); + } /** @@ -362,7 +390,18 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { * @return the completion result containing suggested values. */ public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.delegate.completeCompletion(completeRequest).block(); + return withProvidedContext(this.delegate.completeCompletion(completeRequest)).block(); + + } + + /** + * For a given action, on assembly, capture the "context" via the + * {@link #contextProvider} and store it in the Reactor context. + * @param action the action to perform + * @return the result of the action + */ + private Mono withProvidedContext(Mono action) { + return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java new file mode 100644 index 000000000..ae093316f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -0,0 +1,513 @@ +/* + * Copyright 2024 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * Server-Sent Events (SSE) implementation of the + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE + * transport specification, using Java's HttpClient. + * + *

+ * This transport implementation establishes a bidirectional communication channel between + * client and server using SSE for server-to-client messages and HTTP POST requests for + * client-to-server messages. The transport: + *

    + *
  • Establishes an SSE connection to receive server messages
  • + *
  • Handles endpoint discovery through SSE events
  • + *
  • Manages message serialization/deserialization using Jackson
  • + *
  • Provides graceful connection termination
  • + *
+ * + *

+ * The transport supports two types of SSE events: + *

    + *
  • 'endpoint' - Contains the URL for sending client messages
  • + *
  • 'message' - Contains JSON-RPC message payload
  • + *
+ * + * @author Christian Tzolov + * @see io.modelcontextprotocol.spec.McpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport + */ +public class HttpClientSseClientTransport implements McpClientTransport { + + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + + private static final String MCP_PROTOCOL_VERSION_HEADER_NAME = "MCP-Protocol-Version"; + + private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); + + /** SSE event type for JSON-RPC messages */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + /** SSE event type for endpoint discovery */ + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** Default SSE endpoint path */ + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** Base URI for the MCP server */ + private final URI baseUri; + + /** SSE endpoint path */ + private final String sseEndpoint; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** JSON mapper for message serialization/deserialization */ + protected McpJsonMapper jsonMapper; + + /** Flag indicating if the transport is in closing state */ + private volatile boolean isClosing = false; + + /** Holds the SSE subscription disposable */ + private final AtomicReference sseSubscription = new AtomicReference<>(); + + /** + * Sink for managing the message endpoint URI provided by the server. Stores the most + * recent endpoint URI and makes it available for outbound message processing. + */ + protected final Sinks.One messageEndpointSink = Sinks.one(); + + /** + * Customizer to modify requests before they are executed. + */ + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param jsonMapper the object mapper for JSON serialization/deserialization + * @param httpRequestCustomizer customizer for the requestBuilder before executing + * requests + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.hasText(baseUri, "baseUri must not be empty"); + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + Assert.notNull(httpClient, "httpClient must not be null"); + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); + this.baseUri = URI.create(baseUri); + this.sseEndpoint = sseEndpoint; + this.jsonMapper = jsonMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.httpRequestCustomizer = httpRequestCustomizer; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + /** + * Creates a new builder for {@link HttpClientSseClientTransport}. + * @param baseUri the base URI of the MCP server + * @return a new builder instance + */ + public static Builder builder(String baseUri) { + return new Builder().baseUri(baseUri); + } + + /** + * Builder for {@link HttpClientSseClientTransport}. + */ + public static class Builder { + + private String baseUri; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); + + private McpJsonMapper jsonMapper; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. + */ + @Deprecated(forRemoval = true) + public Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Sets the JSON mapper implementation to use for serialization/deserialization. + * @param jsonMapper the JSON mapper + * @return this builder + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Builds a new {@link HttpClientSseClientTransport} instance. + * @return a new transport instance + */ + public HttpClientSseClientTransport build() { + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer); + } + + } + + @Override + public Mono connect(Function, Mono> handler) { + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + + return Mono.deferContextual(ctx -> { + var builder = requestBuilder.copy() + .uri(uri) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .GET(); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); + }).flatMap(requestBuilder -> Mono.create(sink -> { + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); + } + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + sink.success(); + return Flux.empty(); // No further processing needed + } + else { + sink.error(new RuntimeException("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, + responseEvent.sseEvent().data()); + sink.success(); + return Flux.just(message); + } + else { + logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent()); + sink.success(); + } + } + catch (IOException e) { + sink.error(new McpTransportException("Error processing SSE event", e)); + } + } + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { + if (!isClosing) { + logger.warn("SSE stream observed an error", t); + sink.error(t); + } + return true; + }) + .doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + + this.sseSubscription.set(connection); + })); + } + + /** + * Sends a JSON-RPC message to the server. + * + *

+ * This method waits for the message endpoint to be discovered before sending the + * message. The message is serialized to JSON and sent as an HTTP POST request. + * @param message the JSON-RPC message to send + * @return a Mono that completes when the message is sent + * @throws McpError if the message endpoint is not available or the wait times out + */ + @Override + public Mono sendMessage(JSONRPCMessage message) { + + return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> { + if (isClosing) { + return Mono.empty(); + } + + return this.serializeMessage(message) + .flatMap(body -> sendHttpPost(messageEndpointUri, body).handle((response, sink) -> { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + sink.error(new RuntimeException("Sending message failed with a non-OK HTTP code: " + + response.statusCode() + " - " + response.body())); + } + else { + sink.next(response); + sink.complete(); + } + })) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); + }).then(); + + } + + private Mono serializeMessage(final JSONRPCMessage message) { + return Mono.defer(() -> { + try { + return Mono.just(jsonMapper.writeValueAsString(message)); + } + catch (IOException e) { + return Mono.error(new McpTransportException("Failed to serialize message", e)); + } + }); + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + return Mono.deferContextual(ctx -> { + var builder = this.requestBuilder.copy() + .uri(requestUri) + .header(HttpHeaders.CONTENT_TYPE, "application/json") + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .POST(HttpRequest.BodyPublishers.ofString(body)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body, transportContext)); + }).flatMap(customizedBuilder -> { + var request = customizedBuilder.build(); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }); + } + + /** + * Gracefully closes the transport connection. + * + *

+ * Sets the closing flag and disposes of the SSE subscription. This prevents new + * messages from being sent and allows ongoing operations to complete. + * @return a Mono that completes when the closing process is initiated + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + Disposable subscription = sseSubscription.get(); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + }); + } + + /** + * Unmarshal data to the specified type using the configured object mapper. + * @param data the data to unmarshal + * @param typeRef the type reference for the target type + * @param the target type + * @return the unmarshalled object + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java new file mode 100644 index 000000000..0a8dff363 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -0,0 +1,832 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.time.Duration; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@code WebFluxSseClientTransport}. + *

+ * + * @author Christian Tzolov + * @see Streamable + * HTTP transport specification + */ +public class HttpClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static int NOT_FOUND = 404; + + public static int METHOD_NOT_ALLOWED = 405; + + public static int BAD_REQUEST = 400; + + private final McpJsonMapper jsonMapper; + + private final URI baseUri; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + + private final AtomicReference> activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, + HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, + boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.baseUri = URI.create(baseUri); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + this.httpRequestCustomizer = httpRequestCustomizer; + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); + } + + @Override + public List protocolVersions() { + return supportedProtocolVersions; + } + + public static Builder builder(String baseUri) { + return new Builder(baseUri); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (this.openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).onErrorComplete(t -> { + logger.warn("Eager connect failed ", t); + return true; + }).then(); + } + return Mono.empty(); + }); + } + + private McpTransportSession createTransportSession() { + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : createDelete(sessionId); + return new DefaultMcpTransportSession(onClose); + } + + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + + private Publisher createDelete(String sessionId) { + + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + return Mono.deferContextual(ctx -> { + var builder = this.requestBuilder.copy() + .uri(uri) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .DELETE(); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null, transportContext)); + }).flatMap(requestBuilder -> { + var request = requestBuilder.build(); + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }).then(); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); + if (currentSession != null) { + return Mono.from(currentSession.closeGracefully()); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + + return Mono.deferContextual(ctx -> { + + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + + Disposable connection = Mono.deferContextual(connectionCtx -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } + + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.LAST_EVENT_ID, stream.lastId().get()); + } + + var builder = requestBuilder.uri(uri) + .header(HttpHeaders.ACCEPT, TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + connectionCtx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .GET(); + var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); + }) + .flatMapMany( + requestBuilder -> Flux.create( + sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, + sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( + this.jsonMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), + List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); + } + } + else { + logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger + .debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Bad Request. Status code:" + statusCode + + ", response-event:" + responseEvent)); + + } + + return Flux.error(new McpTransportException( + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + }).flatMap( + jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + + } + + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { + + BodyHandler responseBodyHandler = responseInfo -> { + + String contentType = responseInfo.headers().firstValue(HttpHeaders.CONTENT_TYPE).orElse("").toLowerCase(); + + if (contentType.contains(TEXT_EVENT_STREAM)) { + // For SSE streams, use line subscriber that returns Void + logger.debug("Received SSE stream response, using line subscriber"); + return ResponseSubscribers.sseToBodySubscriber(responseInfo, sink); + } + else if (contentType.contains(APPLICATION_JSON)) { + // For JSON responses and others, use string subscriber + logger.debug("Received response, using string subscriber"); + return ResponseSubscribers.aggregateBodySubscriber(responseInfo, sink); + } + + logger.debug("Received Bodyless response, using discarding subscriber"); + return ResponseSubscribers.bodilessBodySubscriber(responseInfo, sink); + }; + + return responseBodyHandler; + + } + + public String toString(McpSchema.JSONRPCMessage message) { + try { + return this.jsonMapper.writeValueAsString(message); + } + catch (IOException e) { + throw new RuntimeException("Failed to serialize JSON-RPC message", e); + } + } + + public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { + return Mono.create(deliveredSink -> { + logger.debug("Sending message {}", sentMessage); + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + String jsonBody = this.toString(sentMessage); + + Disposable connection = Mono.deferContextual(ctx -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } + + var builder = requestBuilder.uri(uri) + .header(HttpHeaders.ACCEPT, APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON) + .header(HttpHeaders.CACHE_CONTROL, "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody, transportContext)); + }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { + + // Create the async request with proper body subscriber selection + Mono.fromFuture(this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(responseEventSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + responseEventSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); + + })).flatMap(responseEvent -> { + if (transportSession.markInitialized( + responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + + reconnect(null).contextWrite(deliveredSink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + String contentType = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_TYPE) + .orElse("") + .toLowerCase(); + + String contentLength = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_LENGTH) + .orElse(null); + + // For empty content or HTTP code 202 (ACCEPTED), assume success + if (contentType.isBlank() || "0".equals(contentLength) || statusCode == 202) { + // if (contentType.isBlank() || "0".equals(contentLength)) { + logger.debug("No body returned for POST in session {}", sessionRepresentation); + // No content type means no response body, so we can just + // return an empty stream + deliveredSink.success(); + return Flux.empty(); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) + .flatMap(sseEvent -> { + try { + // We don't support batching ATM and probably + // won't + // since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, sseEvent.data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + deliveredSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); + } + }); + } + else if (contentType.contains(APPLICATION_JSON)) { + deliveredSink.success(); + String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(data) ? data : "[empty]"); + return Mono.empty(); + } + + try { + return Mono.just(McpSchema.deserializeJsonRpcMessage(jsonMapper, data)); + } + catch (IOException e) { + return Mono.error(new McpTransportException( + "Error deserializing JSON-RPC message: " + responseEvent, e)); + } + } + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + + return Flux.error( + new RuntimeException("Unknown media type returned: " + contentType)); + } + else if (statusCode == NOT_FOUND) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + logger.debug("Session not found for session ID: {}", transportSession.sessionId().get()); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Server Not Found. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + deliveredSink.error(t); + return true; + }) + .doFinally(s -> { + logger.debug("SendMessage finally: {}", s); + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(deliveredSink.contextView()) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); + } + + /** + * Builder for {@link HttpClientStreamableHttpTransport}. + */ + public static class Builder { + + private final String baseUri; + + private McpJsonMapper jsonMapper; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; + + private Duration connectTimeout = Duration.ofSeconds(10); + + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + */ + private Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Configure a custom {@link McpJsonMapper} implementation to use. + * @param jsonMapper instance to use + * @return the builder instance + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + + /** + * Sets the connection timeout for the HTTP client. + * @param connectTimeout the connection timeout duration + * @return this builder + */ + public Builder connectTimeout(Duration connectTimeout) { + Assert.notNull(connectTimeout, "connectTimeout must not be null"); + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

+ * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + + /** + * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link HttpClientStreamableHttpTransport} + */ + public HttpClientStreamableHttpTransport build() { + HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); + return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, + httpRequestCustomizer, supportedProtocolVersions); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java new file mode 100644 index 000000000..29dc23c35 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -0,0 +1,327 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import org.reactivestreams.FlowAdapters; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.spec.McpTransportException; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.FluxSink; + +/** + * Utility class providing various {@link BodySubscriber} implementations for handling + * different types of HTTP response bodies in the context of Model Context Protocol (MCP) + * clients. + * + *

+ * Defines subscribers for processing Server-Sent Events (SSE), aggregate responses, and + * bodiless responses. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +class ResponseSubscribers { + + private static final Logger logger = LoggerFactory.getLogger(ResponseSubscribers.class); + + record SseEvent(String id, String event, String data) { + } + + sealed interface ResponseEvent permits SseResponseEvent, AggregateResponseEvent, DummyEvent { + + ResponseInfo responseInfo(); + + } + + record DummyEvent(ResponseInfo responseInfo) implements ResponseEvent { + + } + + record SseResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) implements ResponseEvent { + } + + record AggregateResponseEvent(ResponseInfo responseInfo, String data) implements ResponseEvent { + } + + static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink))); + } + + static BodySubscriber aggregateBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new AggregateSubscriber(responseInfo, sink))); + } + + static BodySubscriber bodilessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodilessResponseLineSubscriber(responseInfo, sink))); + } + + static class SseLineSubscriber extends BaseSubscriber { + + /** + * Pattern to extract data content from SSE "data:" lines. + */ + private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event ID from SSE "id:" lines. + */ + private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event type from SSE "event:" lines. + */ + private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * Current event's ID, if specified. + */ + private final AtomicReference currentEventId; + + /** + * Current event's type, if specified. + */ + private final AtomicReference currentEventType; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new LineSubscriber that will emit parsed SSE events to the provided + * sink. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.currentEventId = new AtomicReference<>(); + this.currentEventType = new AtomicReference<>(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + subscription.request(n); + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + this.eventBuilder.setLength(0); + } + } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + this.eventBuilder.append(matcher.group(1).trim()).append("\n"); + } + upstream().request(1); + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventId.set(matcher.group(1).trim()); + } + upstream().request(1); + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventType.set(matcher.group(1).trim()); + } + upstream().request(1); + } + else if (line.startsWith(":")) { + // Ignore comment lines starting with ":" + // This is a no-op, just to skip comments + logger.debug("Ignoring comment line: {}", line); + upstream().request(1); + } + else { + // If the response is not successful, emit an error + this.sink.error(new McpTransportException( + "Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line)); + + } + } + } + + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class AggregateSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + volatile boolean hasRequestedDemand = false; + + /** + * Creates a new JsonLineSubscriber that will emit parsed JSON-RPC messages. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public AggregateSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(subscription::cancel); + } + + @Override + protected void hookOnNext(String line) { + this.eventBuilder.append(line).append("\n"); + } + + @Override + protected void hookOnComplete() { + + if (hasRequestedDemand) { + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); + } + + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class BodilessResponseLineSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + private final ResponseInfo responseInfo; + + volatile boolean hasRequestedDemand = false; + + public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (!hasRequestedDemand) { + subscription.request(Long.MAX_VALUE); + } + hasRequestedDemand = true; + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnComplete() { + if (hasRequestedDemand) { + // emit dummy event to be able to inspect the response info + // this is a shortcut allowing for a more streamlined processing using + // operator composition instead of having to deal with the + // CompletableFuture along the Subscriber for inspecting the result + this.sink.next(new DummyEvent(responseInfo)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 8545348ed..1b4eaca97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -15,8 +15,8 @@ import java.util.function.Consumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -48,7 +48,7 @@ public class StdioClientTransport implements McpClientTransport { /** The server process being communicated with */ private Process process; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; /** Scheduler for handling inbound messages from the server process */ private Scheduler inboundScheduler; @@ -70,29 +70,20 @@ public class StdioClientTransport implements McpClientTransport { private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); /** - * Creates a new StdioClientTransport with the specified parameters and default - * ObjectMapper. + * Creates a new StdioClientTransport with the specified parameters and JsonMapper. * @param params The parameters for configuring the server process + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioClientTransport(ServerParameters params) { - this(params, new ObjectMapper()); - } - - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { + public StdioClientTransport(ServerParameters params, McpJsonMapper jsonMapper) { Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.params = params; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); @@ -259,7 +250,7 @@ private void startInboundProcessing() { String line; while (!isClosing && (line = processReader.readLine()) != null) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { if (!isClosing) { logger.error("Failed to enqueue inbound message: {}", message); @@ -269,7 +260,7 @@ private void startInboundProcessing() { } catch (Exception e) { if (!isClosing) { - logger.error("Error processing inbound message for line: " + line, e); + logger.error("Error processing inbound message for line: {}", line, e); } break; } @@ -300,7 +291,7 @@ private void startOutboundProcessing() { .handle((message, s) -> { if (message != null && !isClosing) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec: // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio // - Messages are delimited by newlines, and MUST NOT contain @@ -362,11 +353,11 @@ public Mono closeGracefully() { } else { logger.warn("Process not started"); - return Mono.empty(); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { - logger.warn("Process terminated with code " + process.exitValue()); + logger.warn("Process terminated with code {}", process.exitValue()); } else { logger.info("MCP server process stopped"); @@ -392,8 +383,8 @@ public Sinks.Many getErrorSink() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..2492efe18 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import org.reactivestreams.Publisher; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +import reactor.core.publisher.Mono; + +/** + * Composable {@link McpAsyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpAsyncHttpClientRequestCustomizer implements McpAsyncHttpClientRequestCustomizer { + + private final List customizers; + + public DelegatingMcpAsyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.customizers = customizers; + } + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + var result = Mono.just(builder); + for (var customizer : this.customizers) { + result = result.flatMap(b -> Mono.from(customizer.customize(b, method, endpoint, body, context))); + } + return result; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e627e7e69 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +/** + * Composable {@link McpSyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpSyncHttpClientRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private final List delegates; + + public DelegatingMcpSyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.delegates = customizers; + } + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + this.delegates.forEach(delegate -> delegate.customize(builder, method, endpoint, body, context)); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..756b39c35 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or + * Streamable HTTP transport. + *

+ * When used in a non-blocking context, implementations MUST be non-blocking. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpAsyncHttpClientRequestCustomizer { + + Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + @Nullable String body, McpTransportContext context); + + McpAsyncHttpClientRequestCustomizer NOOP = new Noop(); + + /** + * Wrap a sync implementation in an async wrapper. + *

+ * Do NOT wrap a blocking implementation for use in a non-blocking context. For a + * blocking implementation, consider using {@link Schedulers#boundedElastic()}. + */ + static McpAsyncHttpClientRequestCustomizer fromSync(McpSyncHttpClientRequestCustomizer customizer) { + return (builder, method, uri, body, context) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body, context); + return builder; + }); + } + + class Noop implements McpAsyncHttpClientRequestCustomizer { + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + return Mono.just(builder); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e22e3aa62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. Do not rely on thread-locals in this implementation, instead + * use {@link SyncSpec#transportContextProvider} to extract context, and then consume it + * through {@link McpTransportContext}. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpSyncHttpClientRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body, + McpTransportContext context); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java new file mode 100644 index 000000000..cde637b15 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation for {@link McpTransportContext} which uses a map as storage. + * + * @author Dariusz Jędrzejczyk + * @author Daniel Garnier-Moiroux + */ +class DefaultMcpTransportContext implements McpTransportContext { + + private final Map metadata; + + DefaultMcpTransportContext(Map metadata) { + Assert.notNull(metadata, "The metadata cannot be null"); + this.metadata = metadata; + } + + @Override + public Object get(String key) { + return this.metadata.get(key); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) + return false; + + DefaultMcpTransportContext that = (DefaultMcpTransportContext) o; + return this.metadata.equals(that.metadata); + } + + @Override + public int hashCode() { + return this.metadata.hashCode(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java new file mode 100644 index 000000000..46a2ccf84 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Collections; +import java.util.Map; + +/** + * Context associated with the transport layer. It allows to add transport-level metadata + * for use further down the line. Specifically, it can be beneficial to extract HTTP + * request metadata for use in MCP feature implementations. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContext { + + /** + * Key for use in Reactor Context to transport the context to user land. + */ + String KEY = "MCP_TRANSPORT_CONTEXT"; + + /** + * An empty, unmodifiable context. + */ + @SuppressWarnings("unchecked") + McpTransportContext EMPTY = new DefaultMcpTransportContext(Collections.EMPTY_MAP); + + /** + * Create an unmodifiable context containing the given metadata. + * @param metadata the transport metadata + * @return the context containing the metadata + */ + static McpTransportContext create(Map metadata) { + return new DefaultMcpTransportContext(metadata); + } + + /** + * Extract a value from the context. + * @param key the key under the data is expected + * @return the associated value or {@code null} if missing. + */ + Object get(String key); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java new file mode 100644 index 000000000..d1b55f594 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.Map; + +class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpStatelessServerHandler.class); + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpStatelessServerHandler(Map> requestHandlers, + Map notificationHandlers) { + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request) { + McpStatelessRequestHandler requestHandler = this.requestHandlers.get(request.method()); + if (requestHandler == null) { + return Mono.error(new McpError("Missing handler for request type: " + request.method())); + } + return requestHandler.handle(transportContext, request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (t instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + t.getMessage(), null); + } + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }); + } + + @Override + public Mono handleNotification(McpTransportContext transportContext, + McpSchema.JSONRPCNotification notification) { + McpStatelessNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("Missing handler for notification type: {}", notification.method()); + return Mono.empty(); + } + return notificationHandler.handle(transportContext, notification.params()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java new file mode 100644 index 000000000..23285d514 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -0,0 +1,1076 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + +/** + * The Model Context Protocol (MCP) server implementation that provides asynchronous + * communication using Project Reactor's Mono and Flux types. + * + *

+ * This server implements the MCP specification, enabling AI models to expose tools, + * resources, and prompts through a standardized interface. Key features include: + *

    + *
  • Asynchronous communication using reactive programming patterns + *
  • Dynamic tool registration and management + *
  • Resource handling with URI-based addressing + *
  • Prompt template management + *
  • Real-time client notifications for state changes + *
  • Structured logging with configurable severity levels + *
  • Support for client-side AI model sampling + *
+ * + *

+ * The server follows a lifecycle: + *

    + *
  1. Initialization - Accepts client connections and negotiates capabilities + *
  2. Normal Operation - Handles client requests and sends notifications + *
  3. Graceful Shutdown - Ensures clean connection termination + *
+ * + *

+ * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + *

+ * The server supports runtime modification of its capabilities through methods like + * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying + * connected clients of changes when configured to do so. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @see McpServer + * @see McpSchema + * @see McpClientSession + */ +public class McpAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); + + private final McpServerTransportProviderBase mcpTransportProvider; + + private final McpJsonMapper jsonMapper; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + /** + * Create a new McpAsyncServer with the given transport provider and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. + * @param features The MCP server supported features. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization + */ + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + private Map prepareNotificationHandlers(McpServerFeatures.Async features) { + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + return notificationHandlers; + } + + private Map> prepareRequestHandlers() { + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + return requestHandlers; + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + private McpNotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> Mono.defer(() -> consumer.apply(exchange, listRootsResult.roots()))) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool call specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new IllegalArgumentException("Tool must not be null")); + } + if (toolSpecification.call() == null && toolSpecification.callHandler() == null) { + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); + + return Mono.defer(() -> { + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); + } + + this.tools.add(wrappedToolSpecification); + logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + private static class StructuredOutputCallToolHandler + implements BiFunction> { + + private final BiFunction> delegateCallToolResult; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final Map outputSchema; + + public StructuredOutputCallToolHandler(JsonSchemaValidator jsonSchemaValidator, + Map outputSchema, + BiFunction> delegateHandler) { + + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + Assert.notNull(delegateHandler, "Delegate call tool result handler must not be null"); + + this.delegateCallToolResult = delegateHandler; + this.outputSchema = outputSchema; + this.jsonSchemaValidator = jsonSchemaValidator; + } + + @Override + public Mono apply(McpAsyncServerExchange exchange, McpSchema.CallToolRequest request) { + + return this.delegateCallToolResult.apply(exchange, request).map(result -> { + + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + + if (outputSchema == null) { + if (result.structuredContent() != null) { + logger.warn( + "Tool call with no outputSchema is not expected to have a result with structured content, but got: {}", + result.structuredContent()); + } + // Pass through. No validation is required if no output schema is + // provided. + return result; + } + + // If an output schema is provided, servers MUST provide structured + // results that conform to this schema. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + if (result.structuredContent() == null) { + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); + } + + // Validate the result against the output schema + var validation = this.jsonSchemaValidator.validate(outputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + + if (Utils.isEmpty(result.content())) { + // For backwards compatibility, a tool that returns structured + // content SHOULD also return functionally equivalent unstructured + // content. (For example, serialized JSON can be returned in a + // TextContent block.) + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); + } + + return result; + }); + } + + } + + private static List withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, List tools) { + + if (Utils.isEmpty(tools)) { + return tools; + } + + return tools.stream().map(tool -> withStructuredOutputHandling(jsonSchemaValidator, tool)).toList(); + } + + private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, McpServerFeatures.AsyncToolSpecification toolSpecification) { + + if (toolSpecification.callHandler() instanceof StructuredOutputCallToolHandler) { + // If the tool is already wrapped, return it as is + return toolSpecification; + } + + if (toolSpecification.tool().outputSchema() == null) { + // If the tool does not have an output schema, return it as is + return toolSpecification; + } + + return McpServerFeatures.AsyncToolSpecification.builder() + .tool(toolSpecification.tool()) + .callHandler(new StructuredOutputCallToolHandler(jsonSchemaValidator, + toolSpecification.tool().outputSchema(), toolSpecification.callHandler())) + .build(); + } + + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpServerFeatures.AsyncToolSpecification::tool); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new IllegalArgumentException("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + } + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpRequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpRequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); + } + + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new IllegalArgumentException("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resources")); + } + + return Mono.defer(() -> { + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()).map(McpServerFeatures.AsyncResourceSpecification::resource); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resources")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + else { + logger.warn("Ignore as a Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates.remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + /** + * Notifies clients that the resources have updated. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, + resourcesUpdatedNotification); + } + + private McpRequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpRequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; + } + + private McpRequestHandler resourcesReadRequestHandler() { + return (ex, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); + + var resourceUri = resourceRequest.uri(); + + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + }; + } + + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + } + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + + return Mono.empty(); + }); + } + + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()).map(McpServerFeatures.AsyncPromptSpecification::prompt); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + return Mono.empty(); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpRequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpRequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); + } + + return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + * @deprecated Use + * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. + */ + @Deprecated + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + + private McpRequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + return Mono.defer(() -> { + + SetLevelRequest newMinLoggingLevel = jsonMapper.convertValue(params, new TypeRef() { + }); + + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); + + return Mono.just(Map.of()); + }); + }; + } + + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + + private McpRequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); + } + + if (request.ref().type() == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); + } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; + } + } + + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + McpServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } + } + + // Handle the completion request using the registered handler + // for the given reference. + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); + } + + return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + Map contextMap = (Map) params.get("context"); + Map meta = (Map) params.get("_meta"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + McpSchema.CompleteRequest.CompleteContext context = null; + if (contextMap != null) { + Map arguments = (Map) contextMap.get("arguments"); + context = new McpSchema.CompleteRequest.CompleteContext(arguments); + } + + return new McpSchema.CompleteRequest(ref, argument, meta, context); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java similarity index 59% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 2fd95a10d..a15c58cd5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -4,12 +4,17 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; +import java.util.ArrayList; +import java.util.Collections; + +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -22,21 +27,26 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final String sessionId; + + private final McpLoggableSession session; private final McpSchema.ClientCapabilities clientCapabilities; private final McpSchema.Implementation clientInfo; - private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private final McpTransportContext transportContext; - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CREATE_MESSAGE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_ROOTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef ELICITATION_RESULT_TYPE_REF = new TypeRef<>() { + }; + + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; /** @@ -45,12 +55,39 @@ public class McpAsyncServerExchange { * @param clientCapabilities The client capabilities that define the supported * features and functionality. * @param clientInfo The client implementation information. + * @deprecated Use + * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} */ - public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + @Deprecated + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.sessionId = null; + if (!(session instanceof McpLoggableSession)) { + throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); + } + this.session = (McpLoggableSession) session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = McpTransportContext.EMPTY; + } + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param transportContext context associated with the client as extracted from the + * transport + */ + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext) { + this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.transportContext = transportContext; } /** @@ -69,6 +106,24 @@ public McpSchema.Implementation getClientInfo() { return this.clientInfo; } + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.transportContext; + } + + /** + * Provides the Session ID. + * @return session ID string + */ + public String sessionId() { + return this.sessionId; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM @@ -126,7 +181,19 @@ public Mono createElicitation(McpSchema.ElicitRequest el * @return A Mono that emits the list of roots result. */ public Mono listRoots() { - return this.listRoots(null); + + // @formatter:off + return this.listRoots(McpSchema.FIRST_PAGE) + .expand(result -> (result.nextCursor() != null) ? + this.listRoots(result.nextCursor()) : Mono.empty()) + .reduce(new McpSchema.ListRootsResult(new ArrayList<>(), null), + (allRootsResult, result) -> { + allRootsResult.roots().addAll(result.roots()); + return allRootsResult; + }) + .map(result -> new McpSchema.ListRootsResult(Collections.unmodifiableList(result.roots()), + result.nextCursor())); + // @formatter:on } /** @@ -152,13 +219,34 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } return Mono.defer(() -> { - if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + if (this.session.isNotificationForLevelAllowed(loggingMessageNotification.level())) { return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); } return Mono.empty(); }); } + /** + * Sends a notification to the client that the current progress status has changed for + * long-running operations. + * @param progressNotification The progress notification to send + * @return A Mono that completes when the notification has been sent + */ + public Mono progressNotification(McpSchema.ProgressNotification progressNotification) { + if (progressNotification == null) { + return Mono.error(new McpError("Progress notification must not be null")); + } + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification); + } + + /** + * Sends a ping request to the client. + * @return A Mono that completes with clients's ping response + */ + public Mono ping() { + return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); + } + /** * Set the minimum logging level for the client. Messages below this level will be * filtered out. @@ -166,11 +254,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN */ void setMinLoggingLevel(LoggingLevel minLoggingLevel) { Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); - this.minLoggingLevel = minLoggingLevel; - } - - private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { - return loggingLevel.level() >= this.minLoggingLevel.level(); + this.session.setMinLoggingLevel(minLoggingLevel); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java new file mode 100644 index 000000000..13ff45a54 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Request handler for the initialization request. + */ +public interface McpInitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java new file mode 100644 index 000000000..6b1061c03 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated notifications. + */ +public interface McpNotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java new file mode 100644 index 000000000..c9d70ad04 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ +public interface McpRequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java new file mode 100644 index 000000000..fe3125271 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -0,0 +1,2362 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.json.McpJsonMapper; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import reactor.core.publisher.Mono; + +/** + * Factory class for creating Model Context Protocol (MCP) servers. MCP servers expose + * tools, resources, and prompts to AI models through a standardized interface. + * + *

+ * This class serves as the main entry point for implementing the server-side of the MCP + * specification. The server's responsibilities include: + *

    + *
  • Exposing tools that models can invoke to perform actions + *
  • Providing access to resources that give models context + *
  • Managing prompt templates for structured model interactions + *
  • Handling client connections and requests + *
  • Implementing capability negotiation + *
+ * + *

+ * Thread Safety: Both synchronous and asynchronous server implementations are + * thread-safe. The synchronous server processes requests sequentially, while the + * asynchronous server can handle concurrent requests safely through its reactive + * programming model. + * + *

+ * Error Handling: The server implementations provide robust error handling through the + * McpError class. Errors are properly propagated to clients while maintaining the + * server's stability. Server implementations should use appropriate error codes and + * provide meaningful error messages to help diagnose issues. + * + *

+ * The class provides factory methods to create either: + *

    + *
  • {@link McpAsyncServer} for non-blocking operations with reactive responses + *
  • {@link McpSyncServer} for blocking operations with direct responses + *
+ * + *

+ * Example of creating a basic synchronous server:

{@code
+ * McpServer.sync(transportProvider)
+ *     .serverInfo("my-server", "1.0.0")
+ *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+ *           (exchange, args) -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
+ *                   .isError(false)
+ *                   .build())
+ *     .build();
+ * }
+ * + * Example of creating a basic asynchronous server:
{@code
+ * McpServer.async(transportProvider)
+ *     .serverInfo("my-server", "1.0.0")
+ *     .tool(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+ *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
+ *               .map(result -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+ *                   .isError(false)
+ *                   .build()))
+ *     .build();
+ * }
+ * + *

+ * Example with comprehensive asynchronous configuration:

{@code
+ * McpServer.async(transportProvider)
+ *     .serverInfo("advanced-server", "2.0.0")
+ *     .capabilities(new ServerCapabilities(...))
+ *     // Register tools
+ *     .tools(
+ *         McpServerFeatures.AsyncToolSpecification.builder()
+ * 			.tool(calculatorTool)
+ *   	    .callTool((exchange, args) -> Mono.fromSupplier(() -> calculate(args.arguments()))
+ *                 .map(result -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+ *                   .isError(false)
+ *                   .build()))
+ *.         .build(),
+ *         McpServerFeatures.AsyncToolSpecification.builder()
+ * 	        .tool((weatherTool)
+ *          .callTool((exchange, args) -> Mono.fromSupplier(() -> getWeather(args.arguments()))
+ *                 .map(result -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Weather: " + result)))
+ *                   .isError(false)
+ *                   .build()))
+ *          .build()
+ *     )
+ *     // Register resources
+ *     .resources(
+ *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
+ *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
+ *                 .map(ReadResourceResult::new)),
+ *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
+ *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
+ *                 .map(ReadResourceResult::new))
+ *     )
+ *     // Add resource templates
+ *     .resourceTemplates(
+ *         new ResourceTemplate("file://{path}", "Access files"),
+ *         new ResourceTemplate("db://{table}", "Access database")
+ *     )
+ *     // Register prompts
+ *     .prompts(
+ *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
+ *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
+ *                 .map(GetPromptResult::new)),
+ *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
+ *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
+ *                 .map(GetPromptResult::new))
+ *     )
+ *     .build();
+ * }
+ * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @see McpAsyncServer + * @see McpSyncServer + * @see McpServerTransportProvider + */ +public interface McpServer { + + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("Java SDK MCP Server", "0.15.0"); + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static SingleSessionSyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SingleSessionSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new SingleSessionAsyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StreamableSyncSpecification sync(McpStreamableServerTransportProvider transportProvider) { + return new StreamableSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + return new StreamableServerAsyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static StatelessAsyncSpecification async(McpStatelessServerTransport transport) { + return new StatelessAsyncSpecification(transport); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { + return new StatelessSyncSpecification(transport); + } + + class SingleSessionAsyncSpecification extends AsyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StreamableServerAsyncSpecification extends AsyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + /** + * Asynchronous server specification. + */ + abstract class AsyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + + Duration requestTimeout = Duration.ofHours(10); // Default timeout + + public abstract McpAsyncServer build(); + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public AsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public AsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public AsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * + *

+ * Example usage:

{@code
+		 * .tool(
+		 *     Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
+		 *         .map(result -> CallToolResult.builder()
+		 *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+		 *                   .isError(false)
+		 *                   .build()))
+		 * )
+		 * }
+ * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpAsyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * map of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool + * calls that require a request object. + */ + @Deprecated + public AsyncSpecification tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public AsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools + .add(McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler(callHandler).build()); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.AsyncToolSpecification...) + */ + public AsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates Map of template URI to specification. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public AsyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates List of template URI to specification. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public AsyncSpecification resourceTemplates( + McpServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.AsyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) + */ + public AsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiFunction) + */ + public AsyncSpecification rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public AsyncSpecification rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public AsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + } + + class SingleSessionSyncSpecification extends SyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + class StreamableSyncSpecification extends SyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : JsonSchemaValidator.getDefault(); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + /** + * Synchronous server specification. + */ + abstract class SyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + JsonSchemaValidator jsonSchemaValidator; + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + final List>> rootsChangeHandlers = new ArrayList<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + boolean immediateExecution = false; + + public abstract McpSyncServer build(); + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return this builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public SyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public SyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public SyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * + *

+ * Example usage:

{@code
+		 * .tool(
+		 *     Tool.builder().name("calculator").title("Performs calculations".inputSchema(schema).build(),
+		 *     (exchange, args) -> CallToolResult.builder()
+		 *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(args))))
+		 *                   .isError(false)
+		 *                   .build())
+		 * )
+		 * }
+ * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool + * calls that require a request object. + */ + @Deprecated + public SyncSpecification tool(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public SyncSpecification toolCall(McpSchema.Tool tool, + BiFunction handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, null, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.SyncToolSpecification...) + */ + public SyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + String toolName = tool.tool().name(); + assertNoDuplicateTool(toolName); // Check against existing tools + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     new ToolSpecification(calculatorTool, calculatorHandler),
+		 *     new ToolSpecification(weatherTool, weatherHandler),
+		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new ResourceSpecification(fileResource, fileHandler),
+		 *     new ResourceSpecification(dbResource, dbHandler),
+		 *     new ResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplates List of resource template specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public SyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.SyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null + * @see #resourceTemplates(List) + */ + public SyncSpecification resourceTemplates( + McpServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * Map prompts = new HashMap<>();
+		 * prompts.put("analysis", new PromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request))
+		 * ));
+		 * .prompts(prompts)
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.SyncPromptSpecification...) + */ + public SyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new PromptSpecification(analysisPrompt, analysisHandler),
+		 *     new PromptSpecification(summaryPrompt, summaryHandler),
+		 *     new PromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + * @see #completions(McpServerFeatures.SyncCompletionSpecification...) + */ + public SyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpSyncServerExchange} upon which the server can interact + * with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public SyncSpecification rootsChangeHandler( + BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiConsumer) + */ + public SyncSpecification rootsChangeHandlers( + List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public SyncSpecification rootsChangeHandlers( + BiConsumer>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(List.of(handlers)); + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public SyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpAsyncServer}. Defaults to false, which does blocking code offloading + * to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public SyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + } + + class StatelessAsyncSpecification { + + private final McpStatelessServerTransport transport; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessAsyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessAsyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessAsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessAsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessAsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *

    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) + */ + public StatelessAsyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessAsyncSpecification tools( + McpStatelessServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessAsyncSpecification resources( + McpStatelessServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public StatelessAsyncSpecification resourceTemplates( + List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessAsyncSpecification resourceTemplates( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.AsyncPromptSpecification...) + */ + public StatelessAsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + McpStatelessServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public StatelessAsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + public McpStatelessAsyncServer build() { + var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); + } + + } + + class StatelessSyncSpecification { + + private final McpStatelessServerTransport transport; + + boolean immediateExecution = false; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + McpJsonMapper jsonMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessSyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessSyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessSyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessSyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessSyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessSyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessSyncSpecification toolCall(McpSchema.Tool tool, + BiFunction callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) + */ + public StatelessSyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessSyncSpecification tools( + McpStatelessServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.SyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessSyncSpecification resources( + McpStatelessServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * @param resourceTemplatesSpec List of resource templates. If null, clears + * existing templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + */ + public StatelessSyncSpecification resourceTemplates( + List resourceTemplatesSpec) { + Assert.notNull(resourceTemplatesSpec, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplatesSpec) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessSyncSpecification resourceTemplates( + McpStatelessServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.SyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.SyncPromptSpecification...) + */ + public StatelessSyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.SyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + McpStatelessServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if jsonMapper is null + */ + public StatelessSyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessSyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpStatelessAsyncServer}. Defaults to false, which does blocking code + * offloading to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public StatelessSyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + public McpStatelessSyncServer build() { + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + var asyncServer = new McpStatelessAsyncServer(transport, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault()); + return new McpStatelessSyncServer(asyncServer, this.immediateExecution); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java similarity index 53% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d41..fe0608b1c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. */ package io.modelcontextprotocol.server; @@ -12,6 +12,7 @@ import java.util.function.BiFunction; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.publisher.Mono; @@ -40,7 +41,7 @@ public class McpServerFeatures { */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -52,7 +53,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes @@ -60,7 +61,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -83,7 +84,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.tools = (tools != null) ? tools : List.of(); this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); @@ -95,28 +96,35 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * blocking code offloading to prevent accidental blocking of the non-blocking * transport. * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. * @return a specification which is protected from blocking calls specified by the * user. */ - static Async fromSync(Sync syncSpec) { + static Async fromSync(Sync syncSpec, boolean immediateExecution) { List tools = new ArrayList<>(); for (var tool : syncSpec.tools()) { - tools.add(AsyncToolSpecification.fromSync(tool)); + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); } Map resources = new HashMap<>(); syncSpec.resources().forEach((key, resource) -> { - resources.put(key, AsyncResourceSpecification.fromSync(resource)); + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); }); Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { - prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); }); Map completions = new HashMap<>(); syncSpec.completions().forEach((key, completion) -> { - completions.put(key, AsyncCompletionSpecification.fromSync(completion)); + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); }); List, Mono>> rootChangeConsumers = new ArrayList<>(); @@ -127,8 +135,8 @@ static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, rootChangeConsumers, syncSpec.instructions()); } } @@ -148,7 +156,7 @@ static Async fromSync(Sync syncSpec) { record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, String instructions) { @@ -168,7 +176,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, @@ -191,7 +199,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.tools = (tools != null) ? tools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); @@ -203,51 +211,110 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se /** * Specification of a tool with its asynchronous handler function. Tools are the * primary way for MCP servers to expose functionality to AI models. Each tool - * represents a specific capability, such as: - *

    - *
  • Performing calculations - *
  • Accessing external APIs - *
  • Querying databases - *
  • Manipulating files - *
  • Executing system commands - *
- * - *

- * Example tool specification:

{@code
-	 * new McpServerFeatures.AsyncToolSpecification(
-	 *     new Tool(
-	 *         "calculator",
-	 *         "Performs mathematical calculations",
-	 *         new JsonSchemaObject()
-	 *             .required("expression")
-	 *             .property("expression", JsonSchemaType.STRING)
-	 *     ),
-	 *     (exchange, args) -> {
-	 *         String expr = (String) args.get("expression");
-	 *         return Mono.fromSupplier(() -> evaluate(expr))
-	 *             .map(result -> new CallToolResult("Result: " + result));
-	 *     }
-	 * )
-	 * }
+ * represents a specific capability. * * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results. The function's first argument is an - * {@link McpAsyncServerExchange} upon which the server can interact with the - * connected client. The second arguments is a map of tool arguments. + * @param call Deprecated. Use the {@link AsyncToolSpecification#callHandler} instead. + * @param callHandler The function that implements the tool's logic, receiving a + * {@link McpAsyncServerExchange} and a + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning + * results. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second arguments is a + * map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { + @Deprecated BiFunction, Mono> call, + BiFunction> callHandler) { + + /** + * @deprecated Use {@link AsyncToolSpecification(McpSchema.Tool, null, + * BiFunction)} instead. + **/ + @Deprecated + public AsyncToolSpecification(McpSchema.Tool tool, + BiFunction, Mono> call) { + this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { - static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented - if (tool == null) { + if (syncToolSpec == null) { return null; } - return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) - .subscribeOn(Schedulers.boundedElastic())); + + BiFunction, Mono> deprecatedCall = (syncToolSpec + .call() != null) ? (exchange, map) -> { + var toolResult = Mono + .fromCallable(() -> syncToolSpec.call().apply(new McpSyncServerExchange(exchange), map)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + } : null; + + BiFunction> callHandler = ( + exchange, req) -> { + var toolResult = Mono + .fromCallable(() -> syncToolSpec.callHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), deprecatedCall, callHandler); + } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, null, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); } } @@ -263,13 +330,19 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { * * *

- * Example resource specification:

{@code
+	 * Example resource specification:
+	 *
+	 * 
{@code
 	 * new McpServerFeatures.AsyncResourceSpecification(
-	 *     new Resource("docs", "Documentation files", "text/markdown"),
-	 *     (exchange, request) ->
-	 *         Mono.fromSupplier(() -> readFile(request.getPath()))
-	 *             .map(ReadResourceResult::new)
-	 * )
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
+	 * 		(exchange, request) -> Mono.fromSupplier(() -> readFile(request.getPath()))
+	 * 				.map(ReadResourceResult::new))
 	 * }
* * @param resource The resource definition including name, description, and MIME type @@ -281,15 +354,57 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { public record AsyncResourceSpecification(McpSchema.Resource resource, BiFunction> readHandler) { - static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { // FIXME: This is temporary, proper validation should be implemented if (resource == null) { return null; } - return new AsyncResourceSpecification(resource.resource(), - (exchange, req) -> Mono - .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); } } @@ -305,16 +420,16 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { * * *

- * Example prompt specification:

{@code
+	 * Example prompt specification:
+	 *
+	 * 
{@code
 	 * new McpServerFeatures.AsyncPromptSpecification(
-	 *     new Prompt("analyze", "Code analysis template"),
-	 *     (exchange, request) -> {
-	 *         String code = request.getArguments().get("code");
-	 *         return Mono.just(new GetPromptResult(
-	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
-	 *         ));
-	 *     }
-	 * )
+	 * 		new Prompt("analyze", "Code analysis template"),
+	 * 		(exchange, request) -> {
+	 * 			String code = request.getArguments().get("code");
+	 * 			return Mono.just(new GetPromptResult(
+	 * 					"Analyze this code:\n\n" + code + "\n\nProvide feedback on:"));
+	 * 		})
 	 * }
* * @param prompt The prompt definition including name and description @@ -327,15 +442,16 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { public record AsyncPromptSpecification(McpSchema.Prompt prompt, BiFunction> promptHandler) { - static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { // FIXME: This is temporary, proper validation should be implemented if (prompt == null) { return null; } - return new AsyncPromptSpecification(prompt.prompt(), - (exchange, req) -> Mono - .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncPromptSpecification(prompt.prompt(), (exchange, req) -> { + var promptResult = Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); } } @@ -366,54 +482,119 @@ public record AsyncCompletionSpecification(McpSchema.CompleteReference reference * @return an asynchronous wrapper of the provided sync specification, or * {@code null} if input is null */ - static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion) { + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { if (completion == null) { return null; } - return new AsyncCompletionSpecification(completion.referenceKey(), - (exchange, request) -> Mono.fromCallable( - () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)) - .subscribeOn(Schedulers.boundedElastic())); + return new AsyncCompletionSpecification(completion.referenceKey(), (exchange, request) -> { + var completionResult = Mono.fromCallable( + () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); } } /** * Specification of a tool with its synchronous handler function. Tools are the - * primary way for MCP servers to expose functionality to AI models. Each tool - * represents a specific capability, such as: - *
    - *
  • Performing calculations - *
  • Accessing external APIs - *
  • Querying databases - *
  • Manipulating files - *
  • Executing system commands - *
+ * primary way for MCP servers to expose functionality to AI models. * *

- * Example tool specification:

{@code
-	 * new McpServerFeatures.SyncToolSpecification(
-	 *     new Tool(
-	 *         "calculator",
-	 *         "Performs mathematical calculations",
-	 *         new JsonSchemaObject()
-	 *             .required("expression")
-	 *             .property("expression", JsonSchemaType.STRING)
-	 *     ),
-	 *     (exchange, args) -> {
-	 *         String expr = (String) args.get("expression");
-	 *         return new CallToolResult("Result: " + evaluate(expr));
-	 *     }
-	 * )
+	 * Example tool specification:
+	 *
+	 * 
{@code
+	 * McpServerFeatures.SyncToolSpecification.builder()
+	 * 		.tool(Tool.builder()
+	 * 				.name("calculator")
+	 * 				.title("Performs mathematical calculations")
+	 * 				.inputSchema(new JsonSchemaObject()
+	 * 						.required("expression")
+	 * 						.property("expression", JsonSchemaType.STRING))
+	 * 				.build()
+	 * 		.toolHandler((exchange, req) -> {
+	 * 			String expr = (String) req.arguments().get("expression");
+	 * 			return CallToolResult.builder()
+	 *                   .content(List.of(new McpSchema.TextContent("Result: " + evaluate(expr))))
+	 *                   .isError(false)
+	 *                   .build();
+	 * 		}))
+	 *      .build();
 	 * }
* * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results. The function's first argument is an + * @param call (Deprected) The function that implements the tool's logic, receiving + * arguments and returning results. The function's first argument is an * {@link McpSyncServerExchange} upon which the server can interact with the connected - * client. The second arguments is a map of arguments passed to the tool. + * @param callHandler The function that implements the tool's logic, receiving a + * {@link McpSyncServerExchange} and a + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning + * results. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the client. The second arguments is a map of + * arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + @Deprecated BiFunction, McpSchema.CallToolResult> call, + BiFunction callHandler) { + + @Deprecated + public SyncToolSpecification(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> call) { + this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, null, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } } /** @@ -428,14 +609,21 @@ public record SyncToolSpecification(McpSchema.Tool tool, * * *

- * Example resource specification:

{@code
+	 * Example resource specification:
+	 *
+	 * 
{@code
 	 * new McpServerFeatures.SyncResourceSpecification(
-	 *     new Resource("docs", "Documentation files", "text/markdown"),
-	 *     (exchange, request) -> {
-	 *         String content = readFile(request.getPath());
-	 *         return new ReadResourceResult(content);
-	 *     }
-	 * )
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
+	 * 		(exchange, request) -> {
+	 * 			String content = readFile(request.getPath());
+	 * 			return new ReadResourceResult(content);
+	 * 		})
 	 * }
* * @param resource The resource definition including name, description, and MIME type @@ -448,6 +636,34 @@ public record SyncResourceSpecification(McpSchema.Resource resource, BiFunction readHandler) { } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + /** * Specification of a prompt template with its synchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: @@ -460,16 +676,16 @@ public record SyncResourceSpecification(McpSchema.Resource resource, * * *

- * Example prompt specification:

{@code
+	 * Example prompt specification:
+	 *
+	 * 
{@code
 	 * new McpServerFeatures.SyncPromptSpecification(
-	 *     new Prompt("analyze", "Code analysis template"),
-	 *     (exchange, request) -> {
-	 *         String code = request.getArguments().get("code");
-	 *         return new GetPromptResult(
-	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
-	 *         );
-	 *     }
-	 * )
+	 * 		new Prompt("analyze", "Code analysis template"),
+	 * 		(exchange, request) -> {
+	 * 			String code = request.getArguments().get("code");
+	 * 			return new GetPromptResult(
+	 * 					"Analyze this code:\n\n" + code + "\n\nProvide feedback on:");
+	 * 		})
 	 * }
* * @param prompt The prompt definition including name and description diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java new file mode 100644 index 000000000..c7a1fd0d7 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -0,0 +1,851 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceTemplateSpecification; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessAsyncServer.class); + + private final McpStatelessServerTransport mcpTransportProvider; + + private final McpJsonMapper jsonMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); + + private final JsonSchemaValidator jsonSchemaValidator; + + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransport; + this.jsonMapper = jsonMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.putAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (ctx, params) -> Mono.just(Map.of())); + + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + this.protocolVersions = new ArrayList<>(mcpTransport.protocolVersions()); + + McpStatelessServerHandler handler = new DefaultMcpStatelessServerHandler(requestHandlers, Map.of()); + mcpTransport.setMcpHandler(handler); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private McpStatelessRequestHandler asyncInitializeRequestHandler() { + return (ctx, req) -> Mono.defer(() -> { + McpSchema.InitializeRequest initializeRequest = this.jsonMapper.convertValue(req, + McpSchema.InitializeRequest.class); + + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + private static List withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, List tools) { + + if (Utils.isEmpty(tools)) { + return tools; + } + + return tools.stream().map(tool -> withStructuredOutputHandling(jsonSchemaValidator, tool)).toList(); + } + + private static McpStatelessServerFeatures.AsyncToolSpecification withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, + McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + + if (toolSpecification.callHandler() instanceof StructuredOutputCallToolHandler) { + // If the tool is already wrapped, return it as is + return toolSpecification; + } + + if (toolSpecification.tool().outputSchema() == null) { + // If the tool does not have an output schema, return it as is + return toolSpecification; + } + + return new McpStatelessServerFeatures.AsyncToolSpecification(toolSpecification.tool(), + new StructuredOutputCallToolHandler(jsonSchemaValidator, toolSpecification.tool().outputSchema(), + toolSpecification.callHandler())); + } + + private static class StructuredOutputCallToolHandler + implements BiFunction> { + + private final BiFunction> delegateHandler; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final Map outputSchema; + + public StructuredOutputCallToolHandler(JsonSchemaValidator jsonSchemaValidator, + Map outputSchema, + BiFunction> delegateHandler) { + + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + Assert.notNull(delegateHandler, "Delegate call tool result handler must not be null"); + + this.delegateHandler = delegateHandler; + this.outputSchema = outputSchema; + this.jsonSchemaValidator = jsonSchemaValidator; + } + + @Override + public Mono apply(McpTransportContext transportContext, McpSchema.CallToolRequest request) { + + return this.delegateHandler.apply(transportContext, request).map(result -> { + + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + + if (outputSchema == null) { + if (result.structuredContent() != null) { + logger.warn( + "Tool call with no outputSchema is not expected to have a result with structured content, but got: {}", + result.structuredContent()); + } + // Pass through. No validation is required if no output schema is + // provided. + return result; + } + + // If an output schema is provided, servers MUST provide structured + // results that conform to this schema. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + if (result.structuredContent() == null) { + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); + } + + // Validate the result against the output schema + var validation = this.jsonSchemaValidator.validate(outputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + + if (Utils.isEmpty(result.content())) { + // For backwards compatibility, a tool that returns structured + // content SHOULD also return functionally equivalent unstructured + // content. (For example, serialized JSON can be returned in a + // TextContent block.) + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); + } + + return result; + }); + } + + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new IllegalArgumentException("Tool must not be null")); + } + if (toolSpecification.callHandler() == null) { + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); + + return Mono.defer(() -> { + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); + } + + this.tools.add(wrappedToolSpecification); + logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + + return Mono.empty(); + }); + } + + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpStatelessServerFeatures.AsyncToolSpecification::tool); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new IllegalArgumentException("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + + logger.debug("Removed tool handler: {}", toolName); + } + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler toolsListRequestHandler() { + return (ctx, params) -> { + List tools = this.tools.stream() + .map(McpStatelessServerFeatures.AsyncToolSpecification::tool) + .toList(); + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpStatelessRequestHandler toolsCallRequestHandler() { + return (ctx, params) -> { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); + } + + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new IllegalArgumentException("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()) + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + } + else { + logger.warn("Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpStatelessServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates + .remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler resourcesListRequestHandler() { + return (ctx, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpStatelessRequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; + } + + private McpStatelessRequestHandler resourcesReadRequestHandler() { + return (ctx, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); + var resourceUri = resourceRequest.uri(); + + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + + }; + } + + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + } + + return Mono.empty(); + }); + } + + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()) + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + return Mono.empty(); + } + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + + return Mono.empty(); + }); + } + + private McpStatelessRequestHandler promptsListRequestHandler() { + return (ctx, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpStatelessRequestHandler promptsGetRequestHandler() { + return (ctx, params) -> { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { + }); + + // Implement prompt retrieval logic here + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); + } + + return specification.promptHandler().apply(ctx, promptRequest); + }; + } + + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + + private McpStatelessRequestHandler completionCompleteRequestHandler() { + return (ctx, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); + } + + if (request.ref().type() == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts + .get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); + } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; + } + } + + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } + } + + McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); + } + + return specification.completionHandler().apply(ctx, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java new file mode 100644 index 000000000..a2fabb283 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP notifications in a stateless server. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessNotificationHandler { + + /** + * Handle to notification and complete once done. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP notification + * @return Mono which completes once the processing is done + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java new file mode 100644 index 000000000..37cd3c096 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests in a stateless server. + * + * @param type of the MCP response + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessRequestHandler { + + /** + * Handle the request and complete with a result. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP request + * @return Mono which completes with the response object + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java new file mode 100644 index 000000000..a15681ba5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -0,0 +1,555 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * MCP stateless server features specification that a particular server can choose to + * support. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + */ +public class McpStatelessServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + null, // currently statless server doesn't support set logging + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The map of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning the result. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction> callHandler) { + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction> callHandler = (ctx, + req) -> { + var toolResult = Mono.fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); + } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *

    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (ctx, req) -> { + var promptResult = Mono.fromCallable(() -> prompt.promptHandler().apply(ctx, req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The function's argument is a + * {@link McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { + var completionResult = Mono.fromCallable(() -> completion.completionHandler().apply(ctx, req)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning results. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction callHandler) { + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, callHandler); + } + + } + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The argument is a {@link McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java new file mode 100644 index 000000000..cbae58bfd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests and notifications in a Stateless Streamable HTTP Server + * context. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessServerHandler { + + /** + * Handle the request using user-provided feature implementations. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param request the request JSON object + * @return Mono containing the JSON response + */ + Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request); + + /** + * Handle the notification. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param notification the notification JSON object + * @return Mono that completes once handling is finished + */ + Mono handleNotification(McpTransportContext transportContext, McpSchema.JSONRPCNotification notification); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java new file mode 100644 index 000000000..6849eb8ed --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -0,0 +1,184 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.List; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessSyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSyncServer.class); + + private final McpStatelessAsyncServer asyncServer; + + private final boolean immediateExecution; + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer, boolean immediateExecution) { + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.asyncServer.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.asyncServer.getServerInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.asyncServer.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.asyncServer.close(); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + */ + public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { + this.asyncServer + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + */ + public void removeTool(String toolName) { + this.asyncServer.removeTool(toolName).block(); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + */ + public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + */ + public void removeResource(String resourceUri) { + this.asyncServer.removeResource(resourceUri).block(); + } + + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate( + McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpStatelessServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + */ + public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + */ + public void removePrompt(String promptName) { + this.asyncServer.removePrompt(promptName).block(); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.asyncServer.setProtocolVersions(protocolVersions); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java similarity index 68% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 91f8d9e4c..10f0e5a31 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import java.util.List; + import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -54,13 +56,27 @@ public class McpSyncServer { */ private final McpAsyncServer asyncServer; + private final boolean immediateExecution; + /** * Creates a new synchronous server that wraps the provided async server. * @param asyncServer The async server to wrap */ public McpSyncServer(McpAsyncServer asyncServer) { + this(asyncServer, false); + } + + /** + * Creates a new synchronous server that wraps the provided async server. + * @param asyncServer The async server to wrap + * @param immediateExecution Tools, prompts, and resources handlers execute work + * without blocking code offloading. Do NOT set to true if the {@code asyncServer}'s + * transport is non-blocking. + */ + public McpSyncServer(McpAsyncServer asyncServer, boolean immediateExecution) { Assert.notNull(asyncServer, "Async server must not be null"); this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; } /** @@ -68,7 +84,17 @@ public McpSyncServer(McpAsyncServer asyncServer) { * @param toolHandler The tool handler to add */ public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { - this.asyncServer.addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler)).block(); + this.asyncServer + .addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler, this.immediateExecution)) + .block(); + } + + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); } /** @@ -81,10 +107,21 @@ public void removeTool(String toolName) { /** * Add a new resource handler. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource specification to add + */ + public void addResource(McpServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * List all registered resources. + * @return A list of all registered resources */ - public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { - this.asyncServer.addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler)).block(); + public List listResources() { + return this.asyncServer.listResources().collectList().block(); } /** @@ -95,12 +132,50 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate(McpServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler. * @param promptSpecification The prompt specification to add */ public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { - this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification)).block(); + this.asyncServer + .addPrompt( + McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java similarity index 80% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 25da5a6f9..0b9115b79 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -4,8 +4,8 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; /** @@ -28,6 +28,14 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + /** + * Provides the Session ID + * @return session ID + */ + public String sessionId() { + return this.exchange.sessionId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities @@ -44,6 +52,16 @@ public McpSchema.Implementation getClientInfo() { return this.exchange.getClientInfo(); } + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM @@ -108,4 +126,21 @@ public void loggingNotification(LoggingMessageNotification loggingMessageNotific this.exchange.loggingNotification(loggingMessageNotification).block(); } + /** + * Sends a notification to the client that the current progress status has changed for + * long-running operations. + * @param progressNotification The progress notification to send + */ + public void progressNotification(McpSchema.ProgressNotification progressNotification) { + this.exchange.progressNotification(progressNotification).block(); + } + + /** + * Sends a synchronous ping request to the client. + * @return + */ + public Object ping() { + return this.exchange.ping().block(); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java new file mode 100644 index 000000000..ea9f05a4f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * The contract for extracting metadata from a generic transport request of type + * {@link T}. + * + * @param transport-specific representation of the request which allows extracting + * metadata for use in the MCP features implementations. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContextExtractor { + + /** + * Extract transport-specific metadata from the request into an McpTransportContext. + * @param request the generic representation for the request in the context of a + * specific transport implementation + * @return the context containing the metadata + */ + McpTransportContext extract(T request); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java similarity index 71% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..96cebb74a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -1,24 +1,31 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -62,7 +69,9 @@ @WebServlet(asyncSupported = true) public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { - /** Logger for this class */ + /** + * Logger for this class + */ private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); public static final String UTF_8 = "UTF-8"; @@ -71,77 +80,112 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - /** Default endpoint path for SSE connections */ + /** + * Default endpoint path for SSE connections + */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - /** Event type for regular messages */ + /** + * Event type for regular messages + */ public static final String MESSAGE_EVENT_TYPE = "message"; - /** Event type for endpoint information */ + /** + * Event type for endpoint information + */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String SESSION_ID = "sessionId"; + public static final String DEFAULT_BASE_URL = ""; - /** JSON object mapper for serialization/deserialization */ - private final ObjectMapper objectMapper; + /** + * JSON mapper for serialization/deserialization + */ + private final McpJsonMapper jsonMapper; - /** Base URL for the server transport */ + /** + * Base URL for the server transport + */ private final String baseUrl; - /** The endpoint path for handling client messages */ + /** + * The endpoint path for handling client messages + */ private final String messageEndpoint; - /** The endpoint path for handling SSE connections */ + /** + * The endpoint path for handling SSE connections + */ private final String sseEndpoint; - /** Map of active client sessions, keyed by session ID */ + /** + * Map of active client sessions, keyed by session ID + */ private final Map sessions = new ConcurrentHashMap<>(); - /** Flag indicating if the transport is in the process of shutting down */ + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is in the process of shutting down + */ private final AtomicBoolean isClosing = new AtomicBoolean(false); - /** Session factory for creating new sessions */ + /** + * Session factory for creating new sessions + */ private McpServerSession.Factory sessionFactory; /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } + private KeepAliveScheduler keepAliveScheduler; /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. - * @param objectMapper The JSON object mapper to use for message + * @param jsonMapper The JSON object mapper to use for message * serialization/deserialization * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @param contextExtractor The extractor for transport context from the request. + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this.objectMapper = objectMapper; + private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(messageEndpoint, "messageEndpoint must not be null"); + Assert.notNull(sseEndpoint, "sseEndpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing.get()) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } } - /** - * Creates a new HttpServletSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } /** @@ -223,7 +267,22 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, buildEndpointUrl(sessionId)); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + if (this.baseUrl.endsWith("/")) { + return this.baseUrl.substring(0, this.baseUrl.length() - 1) + this.messageEndpoint + "?sessionId=" + + sessionId; + } + return this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId; } /** @@ -258,7 +317,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -271,7 +330,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_NOT_FOUND); - String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + String jsonError = jsonMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -286,10 +345,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + final McpTransportContext transportContext = this.contextExtractor.extract(request); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); response.setStatus(HttpServletResponse.SC_OK); } @@ -300,7 +361,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -324,7 +385,13 @@ public Mono closeGracefully() { isClosing.set(true); logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -390,7 +457,7 @@ private class HttpServletMcpSessionTransport implements McpServerTransport { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { try { - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); logger.debug("Message sent to session {}", sessionId); } @@ -403,15 +470,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured JsonMapper. * @param data The source data object to convert * @param typeRef The target type reference - * @return The converted object of type T * @param The target type + * @return The converted object of type T */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -467,7 +534,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; private String baseUrl = DEFAULT_BASE_URL; @@ -475,14 +542,21 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + /** - * Sets the JSON object mapper to use for message serialization/deserialization. - * @param objectMapper The object mapper to use + * Sets the JsonMapper implementation to use for serialization/deserialization. If + * not specified, a JacksonJsonMapper will be created from the configured + * ObjectMapper. + * @param jsonMapper The JsonMapper to use * @return This builder instance for method chaining */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -522,20 +596,44 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. * @return A new HttpServletSseServerTransportProvider instance - * @throws IllegalStateException if objectMapper or messageEndpoint is not set + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set */ public HttpServletSseServerTransportProvider build() { - if (objectMapper == null) { - throw new IllegalStateException("ObjectMapper must be set"); - } if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval, contextExtractor); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java new file mode 100644 index 000000000..40767f416 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -0,0 +1,305 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.McpJsonMapper; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Mono; + +/** + * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@WebServlet(asyncSupported = true) +public class HttpServletStatelessServerTransport extends HttpServlet implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStatelessServerTransport.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String ACCEPT = "Accept"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Handles GET requests - returns 405 METHOD NOT ALLOWED as stateless transport + * doesn't support GET requests. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Both application/json and text/event-stream required in Accept header")); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponseText = jsonMapper.writeValueAsString(jsonrpcResponse); + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Sends an error response to the client. + * @param response The HTTP servlet response + * @param httpCode The HTTP status code + * @param mcpError The MCP error to send + * @throws IOException If an I/O error occurs + */ + private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = jsonMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

+ * This method ensures a graceful shutdown before calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link HttpServletStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * HttpServletStatelessServerTransport with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStatelessServerTransport} with the + * configured settings. + * @return A new HttpServletStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java new file mode 100644 index 000000000..34671c105 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -0,0 +1,851 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through HttpServlet. This implementation + * provides a bridge between synchronous HttpServlet operations and reactive programming + * patterns to maintain compatibility with the reactive transport interface. + * + *

+ * This is the HttpServlet equivalent of + * {@link io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider} + * for the core MCP module, providing streamable HTTP transport functionality without + * Spring dependencies. + * + * @author Zachary German + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see HttpServlet + */ +@WebServlet(asyncSupported = true) +public class HttpServletStreamableServerTransportProvider extends HttpServlet + implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Header name for the response media types accepted by the requester. + */ + private static final String ACCEPT = "Accept"; + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final McpJsonMapper jsonMapper; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new HttpServletStreamableServerTransportProvider instance. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of + * messages. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @param contextExtractor The extractor for transport context from the request. + * @throws IllegalArgumentException if any parameter is null + */ + private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Handles GET requests to establish SSE connections and message replay. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + logger.debug("Handling GET request for session: {}", sessionId); + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + // Check if this is a replay request + if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { + String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + asyncContext.complete(); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + }); + } + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + if (accept == null || !accept.contains(APPLICATION_JSON)) { + badRequestErrors.add("application/json required in Accept header"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId()); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponse = jsonMapper.writeValueAsString(new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); + + PrintWriter writer = response.getWriter(); + writer.write(jsonResponse); + writer.flush(); + return; + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to initialize session: " + e.getMessage())); + return; + } + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + this.responseError(response, HttpServletResponse.SC_NOT_FOUND, + new McpError("Session not found: " + sessionId)); + return; + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Invalid message format: " + e.getMessage())); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Error processing message: " + e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + if (this.disallowDelete) { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Session ID required in mcp-session-id header")); + return; + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError(e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error deleting session"); + } + } + } + + public void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = jsonMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + /** + * Sends an SSE event to a client with a specific ID. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @param id The event ID + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException { + if (id != null) { + writer.write("id: " + id + "\n"); + } + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

+ * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpStreamableServerTransport for HttpServlet SSE sessions. This + * class handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying PrintWriter to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + + private class HttpServletStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + private volatile boolean closed = false; + + private final ReentrantLock lock = new ReentrantLock(); + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletStreamableMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Streamable session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = jsonMapper.writeValueAsString(message); + HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, + messageId != null ? messageId : this.sessionId); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + } + finally { + lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured JsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + HttpServletStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + finally { + lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of + * {@link HttpServletStreamableServerTransportProvider}. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if JsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be activated to periodically ping active sessions. + * @param keepAliveInterval The interval for keep-alive pings. If null, no + * keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} + * with the configured settings. + * @return A new HttpServletStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStreamableServerTransportProvider build() { + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + return new HttpServletStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java similarity index 88% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da9777..68be62931 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -9,22 +9,22 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; -import java.io.Reader; import java.nio.charset.StandardCharsets; -import java.util.Map; +import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -44,7 +44,7 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final InputStream inputStream; @@ -56,40 +56,37 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final Sinks.One inboundReady = Sinks.one(); - /** - * Creates a new StdioServerTransportProvider with a default ObjectMapper and System - * streams. - */ - public StdioServerTransportProvider() { - this(new ObjectMapper()); - } - /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * System streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioServerTransportProvider(ObjectMapper objectMapper) { - this(objectMapper, System.in, System.out); + public StdioServerTransportProvider(McpJsonMapper jsonMapper) { + this(jsonMapper, System.in, System.out); } /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization * @param inputStream The input stream to read from * @param outputStream The output stream to write to */ - public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + public StdioServerTransportProvider(McpJsonMapper jsonMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); Assert.notNull(inputStream, "The InputStream can not be null"); Assert.notNull(outputStream, "The OutputStream can not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.inputStream = inputStream; this.outputStream = outputStream; } + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection @@ -160,8 +157,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -214,7 +211,7 @@ private void startInboundProcessing() { logger.debug("Received JSON message: {}", line); try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { // logIfNotClosing("Failed to enqueue message"); @@ -258,7 +255,7 @@ private void startOutboundProcessing() { .handle((message, sink) -> { if (message != null && !isClosing.get()) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java new file mode 100644 index 000000000..b18364abb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package io.modelcontextprotocol.spec; + +import java.util.Optional; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Represents a closed MCP session, which may not be reused. All calls will throw a + * {@link McpTransportSessionClosedException}. + * + * @param the resource representing the connection that the transport + * manages. + * @author Daniel Garnier-Moiroux + */ +public class ClosedMcpTransportSession implements McpTransportSession { + + private final String sessionId; + + public ClosedMcpTransportSession(@Nullable String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Optional sessionId() { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public boolean markInitialized(String sessionId) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void addConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void removeConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void close() { + + } + + @Override + public Publisher closeGracefully() { + return Mono.empty(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java new file mode 100644 index 000000000..f497afd43 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; + +/** + * A default implementation of {@link McpStreamableServerSession.Factory}. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { + + Duration requestTimeout; + + McpStreamableServerSession.InitRequestHandler initRequestHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + /** + * Constructs an instance + * @param requestTimeout timeout for requests + * @param initRequestHandler initialization request handler + * @param requestHandlers map of MCP request handlers keyed by method name + * @param notificationHandlers map of MCP notification handlers keyed by method name + */ + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, + McpStreamableServerSession.InitRequestHandler initRequestHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index 56cdeaf7f..fdb7bfd89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java similarity index 78% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java index ecc6f8666..8d63fb50d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; @@ -61,14 +65,19 @@ public long streamId() { @Override public Publisher consumeSseStream( Publisher, Iterable>> eventStream) { - return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { - Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); - } - }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { - String previousId = this.lastId.getAndSet(id); - logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); - })).flatMapIterable(Tuple2::getT2)); + + // @formatter:off + return Flux.deferContextual(ctx -> Flux.from(eventStream) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })) + .doOnError(e -> { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }) + .flatMapIterable(Tuple2::getT2)); // @formatter:on } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java new file mode 100644 index 000000000..6afc2c119 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Names of HTTP headers in use by MCP HTTP transports. + * + * @author Dariusz Jędrzejczyk + */ +public interface HttpHeaders { + + /** + * Identifies individual MCP sessions. + */ + String MCP_SESSION_ID = "Mcp-Session-Id"; + + /** + * Identifies events within an SSE Stream. + */ + String LAST_EVENT_ID = "Last-Event-ID"; + + /** + * Identifies the MCP protocol version. + */ + String PROTOCOL_VERSION = "MCP-Protocol-Version"; + + /** + * The HTTP Content-Length header. + * @see RFC9110 + */ + String CONTENT_LENGTH = "Content-Length"; + + /** + * The HTTP Content-Type header. + * @see RFC9110 + */ + String CONTENT_TYPE = "Content-Type"; + + /** + * The HTTP Accept header. + * @see RFC9110 + */ + String ACCEPT = "Accept"; + + /** + * The HTTP Cache-Control header. + * @see RFC9111 + */ + String CACHE_CONTROL = "Cache-Control"; + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java new file mode 100644 index 000000000..4a42c9ff3 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + public record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java similarity index 87% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 36aa18817..0ba7ab3b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -35,6 +35,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Yanming Zhou */ public class McpClientSession implements McpSession { @@ -146,29 +147,46 @@ private void dismissPendingResponses() { private void handle(McpSchema.JSONRPCMessage message) { if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); handleIncomingRequest(request).onErrorResume(error -> { + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage).subscribe(); + jsonRpcError); + return Mono.just(errorResponse); + }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { + logger.warn("Issue sending response to the client, ", t); + return true; + }).subscribe(); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) - .subscribe(); + handleIncomingNotification(notification).onErrorComplete(t -> { + logger.error("Error handling notification: {}", t.getMessage()); + return true; + }).subscribe(); } else { logger.warn("Received unknown message type: {}", message); @@ -191,11 +209,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return handler.handle(request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)); }); } @@ -221,7 +235,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { var handler = notificationHandlers.get(notification.method()); if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); + logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } return handler.handle(notification.params()); @@ -246,7 +260,7 @@ private String generateRequestId() { * @return A Mono containing the response */ @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java similarity index 99% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 5c3b33131..22aec831b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import java.util.function.Consumer; diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java new file mode 100644 index 000000000..d6e549fdc --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -0,0 +1,116 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import io.modelcontextprotocol.util.Assert; + +import java.util.Map; +import java.util.function.Function; + +public class McpError extends RuntimeException { + + /** + * Resource + * Error Handling + */ + public static final Function RESOURCE_NOT_FOUND = resourceUri -> new McpError(new JSONRPCError( + McpSchema.ErrorCodes.RESOURCE_NOT_FOUND, "Resource not found", Map.of("uri", resourceUri))); + + private JSONRPCError jsonRpcError; + + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.message()); + this.jsonRpcError = jsonRpcError; + } + + @Deprecated + public McpError(Object error) { + super(error.toString()); + } + + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + + @Override + public String toString() { + var builder = new StringBuilder(super.toString()); + if (jsonRpcError != null) { + builder.append("\n"); + builder.append(jsonRpcError.toString()); + } + return builder.toString(); + } + + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + + public static Throwable findRootCause(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + Throwable rootCause = throwable; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + return rootCause; + } + + public static String aggregateExceptionMessages(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + + StringBuilder messages = new StringBuilder(); + Throwable current = throwable; + + while (current != null) { + if (messages.length() > 0) { + messages.append("\n Caused by: "); + } + + messages.append(current.getClass().getSimpleName()); + if (current.getMessage() != null) { + messages.append(": ").append(current.getMessage()); + } + + if (current.getCause() == current) { + break; + } + current = current.getCause(); + } + + return messages.toString(); + } + +} \ No newline at end of file diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java new file mode 100644 index 000000000..f43a2c1d9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * An {@link McpSession} which is capable of processing logging notifications and keeping + * track of a min logging level. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpLoggableSession extends McpSession { + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel); + + /** + * Allows checking whether a particular logging level is allowed. + * @param loggingLevel the level to check + * @return whether the logging at the specified level is permitted. + */ + boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java new file mode 100644 index 000000000..b58f1c552 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -0,0 +1,2931 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Based on the JSON-RPC 2.0 + * specification and the Model + * Context Protocol Schema. + * + * @author Christian Tzolov + * @author Luca Chang + * @author Surbhi Bansal + * @author Anurag Pant + */ +public final class McpSchema { + + private static final Logger logger = LoggerFactory.getLogger(McpSchema.class); + + private McpSchema() { + } + + @Deprecated + public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_06_18; + + public static final String JSONRPC_VERSION = "2.0"; + + public static final String FIRST_PAGE = null; + + // --------------------------- + // Method Names + // --------------------------- + + // Lifecycle Methods + public static final String METHOD_INITIALIZE = "initialize"; + + public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized"; + + public static final String METHOD_PING = "ping"; + + public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress"; + + // Tool Methods + public static final String METHOD_TOOLS_LIST = "tools/list"; + + public static final String METHOD_TOOLS_CALL = "tools/call"; + + public static final String METHOD_NOTIFICATION_TOOLS_LIST_CHANGED = "notifications/tools/list_changed"; + + // Resources Methods + public static final String METHOD_RESOURCES_LIST = "resources/list"; + + public static final String METHOD_RESOURCES_READ = "resources/read"; + + public static final String METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"; + + public static final String METHOD_NOTIFICATION_RESOURCES_UPDATED = "notifications/resources/updated"; + + public static final String METHOD_RESOURCES_TEMPLATES_LIST = "resources/templates/list"; + + public static final String METHOD_RESOURCES_SUBSCRIBE = "resources/subscribe"; + + public static final String METHOD_RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"; + + // Prompt Methods + public static final String METHOD_PROMPT_LIST = "prompts/list"; + + public static final String METHOD_PROMPT_GET = "prompts/get"; + + public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; + + public static final String METHOD_COMPLETION_COMPLETE = "completion/complete"; + + // Logging Methods + public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; + + public static final String METHOD_NOTIFICATION_MESSAGE = "notifications/message"; + + // Roots Methods + public static final String METHOD_ROOTS_LIST = "roots/list"; + + public static final String METHOD_NOTIFICATION_ROOTS_LIST_CHANGED = "notifications/roots/list_changed"; + + // Sampling Methods + public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + + // --------------------------- + // JSON-RPC Error Codes + // --------------------------- + /** + * Standard error codes used in MCP JSON-RPC responses. + */ + public static final class ErrorCodes { + + /** + * Invalid JSON was received by the server. + */ + public static final int PARSE_ERROR = -32700; + + /** + * The JSON sent is not a valid Request object. + */ + public static final int INVALID_REQUEST = -32600; + + /** + * The method does not exist / is not available. + */ + public static final int METHOD_NOT_FOUND = -32601; + + /** + * Invalid method parameter(s). + */ + public static final int INVALID_PARAMS = -32602; + + /** + * Internal JSON-RPC error. + */ + public static final int INTERNAL_ERROR = -32603; + + /** + * Resource not found. + */ + public static final int RESOURCE_NOT_FOUND = -32002; + + } + + /** + * Base interface for MCP objects that include optional metadata in the `_meta` field. + */ + public interface Meta { + + /** + * @see Specification + * for notes on _meta usage + * @return additional metadata related to this resource. + */ + Map meta(); + + } + + public sealed interface Request extends Meta + permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, + GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + + default Object progressToken() { + if (meta() != null && meta().containsKey("progressToken")) { + return meta().get("progressToken"); + } + return null; + } + + } + + public sealed interface Result extends Meta permits InitializeResult, ListResourcesResult, + ListResourceTemplatesResult, ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, + CallToolResult, CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { + + } + + public sealed interface Notification extends Meta + permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { + + } + + private static final TypeRef> MAP_TYPE_REF = new TypeRef<>() { + }; + + /** + * Deserializes a JSON string into a JSONRPCMessage object. + * @param jsonMapper The JsonMapper instance to use for deserialization + * @param jsonText The JSON string to deserialize + * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, + * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. + * @throws IOException If there's an error during deserialization + * @throws IllegalArgumentException If the JSON structure doesn't match any known + * message type + */ + public static JSONRPCMessage deserializeJsonRpcMessage(McpJsonMapper jsonMapper, String jsonText) + throws IOException { + + logger.debug("Received JSON message: {}", jsonText); + + var map = jsonMapper.readValue(jsonText, MAP_TYPE_REF); + + // Determine message type based on specific JSON structure + if (map.containsKey("method") && map.containsKey("id")) { + return jsonMapper.convertValue(map, JSONRPCRequest.class); + } + else if (map.containsKey("method") && !map.containsKey("id")) { + return jsonMapper.convertValue(map, JSONRPCNotification.class); + } + else if (map.containsKey("result") || map.containsKey("error")) { + return jsonMapper.convertValue(map, JSONRPCResponse.class); + } + + throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); + } + + // --------------------------- + // JSON-RPC Message Types + // --------------------------- + public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { + + String jsonrpc(); + + } + + /** + * A request that expects a response. + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param method The name of the method to be invoked + * @param id A unique identifier for the request + * @param params Parameters for the method call + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCRequest( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("method") String method, + @JsonProperty("id") Object id, + @JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on + + /** + * Constructor that validates MCP-specific ID requirements. Unlike base JSON-RPC, + * MCP requires that: (1) Requests MUST include a string or integer ID; (2) The ID + * MUST NOT be null + */ + public JSONRPCRequest { + Assert.notNull(id, "MCP requests MUST include an ID - null IDs are not allowed"); + Assert.isTrue(id instanceof String || id instanceof Integer || id instanceof Long, + "MCP requests MUST have an ID that is either a string or integer"); + } + } + + /** + * A notification which does not expect a response. + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param method The name of the method being notified + * @param params Parameters for the notification + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCNotification( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("method") String method, + @JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on + } + + /** + * A response to a request (successful, or error). + * + * @param jsonrpc The JSON-RPC version (must be "2.0") + * @param id The request identifier that this response corresponds to + * @param result The result of the successful request; null if error + * @param error Error information if the request failed; null if has result + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + public record JSONRPCResponse( // @formatter:off + @JsonProperty("jsonrpc") String jsonrpc, + @JsonProperty("id") Object id, + @JsonProperty("result") Object result, + @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { // @formatter:on + + /** + * A response to a request that indicates an error occurred. + * + * @param code The error type that occurred + * @param message A short description of the error. The message SHOULD be limited + * to a concise single sentence + * @param data Additional information about the error. The value of this member is + * defined by the sender (e.g. detailed error information, nested errors etc.) + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCError( // @formatter:off + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { // @formatter:on + } + } + + // --------------------------- + // Initialization + // --------------------------- + /** + * This request is sent from the client to the server when it first connects, asking + * it to begin initialization. + * + * @param protocolVersion The latest version of the Model Context Protocol that the + * client supports. The client MAY decide to support older versions as well + * @param capabilities The capabilities that the client supports + * @param clientInfo Information about the client implementation + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record InitializeRequest( // @formatter:off + @JsonProperty("protocolVersion") String protocolVersion, + @JsonProperty("capabilities") ClientCapabilities capabilities, + @JsonProperty("clientInfo") Implementation clientInfo, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public InitializeRequest(String protocolVersion, ClientCapabilities capabilities, Implementation clientInfo) { + this(protocolVersion, capabilities, clientInfo, null); + } + } + + /** + * After receiving an initialize request from the client, the server sends this + * response. + * + * @param protocolVersion The version of the Model Context Protocol that the server + * wants to use. This may not match the version that the client requested. If the + * client cannot support this version, it MUST disconnect + * @param capabilities The capabilities that the server supports + * @param serverInfo Information about the server implementation + * @param instructions Instructions describing how to use the server and its features. + * This can be used by clients to improve the LLM's understanding of available tools, + * resources, etc. It can be thought of like a "hint" to the model. For example, this + * information MAY be added to the system prompt + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record InitializeResult( // @formatter:off + @JsonProperty("protocolVersion") String protocolVersion, + @JsonProperty("capabilities") ServerCapabilities capabilities, + @JsonProperty("serverInfo") Implementation serverInfo, + @JsonProperty("instructions") String instructions, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public InitializeResult(String protocolVersion, ServerCapabilities capabilities, Implementation serverInfo, + String instructions) { + this(protocolVersion, capabilities, serverInfo, instructions, null); + } + } + + /** + * Capabilities a client may support. Known capabilities are defined here, in this + * schema, but this is not a closed set: any client can define its own, additional + * capabilities. + * + * @param experimental Experimental, non-standard capabilities that the client + * supports + * @param roots Present if the client supports listing roots + * @param sampling Present if the client supports sampling from an LLM + * @param elicitation Present if the client supports elicitation from the server + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ClientCapabilities( // @formatter:off + @JsonProperty("experimental") Map experimental, + @JsonProperty("roots") RootCapabilities roots, + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { // @formatter:on + + /** + * Present if the client supports listing roots. + * + * @param listChanged Whether the client supports notifications for changes to the + * roots list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record RootCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Provides a standardized way for servers to request LLM sampling ("completions" + * or "generations") from language models via clients. This flow allows clients to + * maintain control over model access, selection, and permissions while enabling + * servers to leverage AI capabilities—with no server API keys necessary. Servers + * can request text or image-based interactions and optionally include context + * from MCP servers in their prompts. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Sampling() { + } + + /** + * Provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to + * maintain control over user interactions and data sharing while enabling servers + * to gather necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + * + *

+ * Per the 2025-11-25 spec, clients can declare support for specific elicitation + * modes: + *

    + *
  • {@code form} - In-band structured data collection with optional schema + * validation
  • + *
  • {@code url} - Out-of-band interaction via URL navigation
  • + *
+ * + *

+ * For backward compatibility, an empty elicitation object {@code {}} is + * equivalent to declaring support for form mode only. + * + * @param form support for in-band form-based elicitation + * @param url support for out-of-band URL-based elicitation + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation(@JsonProperty("form") Form form, @JsonProperty("url") Url url) { + + /** + * Marker record indicating support for form-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Form() { + } + + /** + * Marker record indicating support for URL-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Url() { + } + + /** + * Creates an Elicitation with default settings (backward compatible, produces + * empty JSON object). + */ + public Elicitation() { + this(null, null); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Map experimental; + + private RootCapabilities roots; + + private Sampling sampling; + + private Elicitation elicitation; + + public Builder experimental(Map experimental) { + this.experimental = experimental; + return this; + } + + public Builder roots(Boolean listChanged) { + this.roots = new RootCapabilities(listChanged); + return this; + } + + public Builder sampling() { + this.sampling = new Sampling(); + return this; + } + + /** + * Enables elicitation capability with default settings (backward compatible, + * produces empty JSON object). + * @return this builder + */ + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + + /** + * Enables elicitation capability with explicit form and/or url mode support. + * @param form whether to support form-based elicitation + * @param url whether to support URL-based elicitation + * @return this builder + */ + public Builder elicitation(boolean form, boolean url) { + this.elicitation = new Elicitation(form ? new Elicitation.Form() : null, + url ? new Elicitation.Url() : null); + return this; + } + + public ClientCapabilities build() { + return new ClientCapabilities(experimental, roots, sampling, elicitation); + } + + } + } + + /** + * Capabilities that a server may support. Known capabilities are defined here, in + * this schema, but this is not a closed set: any server can define its own, + * additional capabilities. + * + * @param completions Present if the server supports argument autocompletion + * suggestions + * @param experimental Experimental, non-standard capabilities that the server + * supports + * @param logging Present if the server supports sending log messages to the client + * @param prompts Present if the server offers any prompt templates + * @param resources Present if the server offers any resources to read + * @param tools Present if the server offers any tools to call + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ServerCapabilities( // @formatter:off + @JsonProperty("completions") CompletionCapabilities completions, + @JsonProperty("experimental") Map experimental, + @JsonProperty("logging") LoggingCapabilities logging, + @JsonProperty("prompts") PromptCapabilities prompts, + @JsonProperty("resources") ResourceCapabilities resources, + @JsonProperty("tools") ToolCapabilities tools) { // @formatter:on + + /** + * Present if the server supports argument autocompletion suggestions. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record CompletionCapabilities() { + } + + /** + * Present if the server supports sending log messages to the client. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record LoggingCapabilities() { + } + + /** + * Present if the server offers any prompt templates. + * + * @param listChanged Whether this server supports notifications for changes to + * the prompt list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record PromptCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Present if the server offers any resources to read. + * + * @param subscribe Whether this server supports subscribing to resource updates + * @param listChanged Whether this server supports notifications for changes to + * the resource list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record ResourceCapabilities(@JsonProperty("subscribe") Boolean subscribe, + @JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Present if the server offers any tools to call. + * + * @param listChanged Whether this server supports notifications for changes to + * the tool list + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record ToolCapabilities(@JsonProperty("listChanged") Boolean listChanged) { + } + + /** + * Create a mutated copy of this object with the specified changes. + * @return A new Builder instance with the same values as this object. + */ + public Builder mutate() { + var builder = new Builder(); + builder.completions = this.completions; + builder.experimental = this.experimental; + builder.logging = this.logging; + builder.prompts = this.prompts; + builder.resources = this.resources; + builder.tools = this.tools; + return builder; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private CompletionCapabilities completions; + + private Map experimental; + + private LoggingCapabilities logging; + + private PromptCapabilities prompts; + + private ResourceCapabilities resources; + + private ToolCapabilities tools; + + public Builder completions() { + this.completions = new CompletionCapabilities(); + return this; + } + + public Builder experimental(Map experimental) { + this.experimental = experimental; + return this; + } + + public Builder logging() { + this.logging = new LoggingCapabilities(); + return this; + } + + public Builder prompts(Boolean listChanged) { + this.prompts = new PromptCapabilities(listChanged); + return this; + } + + public Builder resources(Boolean subscribe, Boolean listChanged) { + this.resources = new ResourceCapabilities(subscribe, listChanged); + return this; + } + + public Builder tools(Boolean listChanged) { + this.tools = new ToolCapabilities(listChanged); + return this; + } + + public ServerCapabilities build() { + return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); + } + + } + } + + /** + * Describes the name and version of an MCP implementation, with an optional title for + * UI representation. + * + * @param name Intended for programmatic or logical use, but used as a display name in + * past specs or fallback (if title isn't present). + * @param title Intended for UI and end-user contexts + * @param version The version of the implementation. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Implementation( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("version") String version) implements Identifier { // @formatter:on + + public Implementation(String name, String version) { + this(name, null, version); + } + } + + // Existing Enums and Base Types (from previous implementation) + public enum Role { + + // @formatter:off + @JsonProperty("user") USER, + @JsonProperty("assistant") ASSISTANT + } // @formatter:on + + // --------------------------- + // Resource Interfaces + // --------------------------- + /** + * Base for objects that include optional annotations for the client. The client can + * use annotations to inform how objects are used or displayed + */ + public interface Annotated { + + Annotations annotations(); + + } + + /** + * Optional annotations for the client. The client can use annotations to inform how + * objects are used or displayed. + * + * @param audience Describes who the intended customer of this object or data is. It + * can include multiple entries to indicate content useful for multiple audiences + * (e.g., `["user", "assistant"]`). + * @param priority Describes how important this data is for operating the server. A + * value of 1 means "most important," and indicates that the data is effectively + * required, while 0 means "least important," and indicates that the data is entirely + * optional. It is a number between 0 and 1. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Annotations( // @formatter:off + @JsonProperty("audience") List audience, + @JsonProperty("priority") Double priority, + @JsonProperty("lastModified") String lastModified + ) { // @formatter:on + + public Annotations(List audience, Double priority) { + this(audience, priority, null); + } + } + + /** + * A common interface for resource content, which includes metadata about the resource + * such as its URI, name, description, MIME type, size, and annotations. This + * interface is implemented by both {@link Resource} and {@link ResourceLink} to + * provide a consistent way to access resource metadata. + */ + public interface ResourceContent extends Identifier, Annotated, Meta { + + // name & title from Identifier + + String uri(); + + String description(); + + String mimeType(); + + Long size(); + + // annotations from Annotated + // meta from Meta + + } + + /** + * Base interface with name (identifier) and title (display name) properties. + */ + public interface Identifier { + + /** + * Intended for programmatic or logical use, but used as a display name in past + * specs or fallback (if title isn't present). + */ + String name(); + + /** + * Intended for UI and end-user contexts — optimized to be human-readable and + * easily understood, even by those unfamiliar with domain-specific terminology. + * + * If not provided, the name should be used for display. + */ + String title(); + + } + + /** + * A known resource that the server is capable of reading. + * + * @param uri the URI of the resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title An optional title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param size The size of the raw resource content, in bytes (i.e., before base64 + * encoding or any tokenization), if known. This can be used by Hosts to display file + * sizes and estimate context window usage. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Resource( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("size") Long size, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements ResourceContent { // @formatter:on + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String title, String description, String mimeType, Long size, + Annotations annotations) { + this(uri, name, title, description, mimeType, size, annotations, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String description, String mimeType, Long size, + Annotations annotations) { + this(uri, name, null, description, mimeType, size, annotations, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link Resource#builder()} instead. + */ + @Deprecated + public Resource(String uri, String name, String description, String mimeType, Annotations annotations) { + this(uri, name, null, description, mimeType, null, annotations, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uri; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Long size; + + private Annotations annotations; + + private Map meta; + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder size(Long size) { + this.size = size; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Resource build() { + Assert.hasText(uri, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new Resource(uri, name, title, description, mimeType, size, annotations, meta); + } + + } + } + + /** + * Resource templates allow servers to expose parameterized resources using URI + * + * @param uriTemplate A URI template that can be used to generate URIs for this + * resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title An optional title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @see RFC 6570 + * @param meta See specification for notes on _meta usage + * + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceTemplate( // @formatter:off + @JsonProperty("uriTemplate") String uriTemplate, + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements Annotated, Identifier, Meta { // @formatter:on + + public ResourceTemplate(String uriTemplate, String name, String title, String description, String mimeType, + Annotations annotations) { + this(uriTemplate, name, title, description, mimeType, annotations, null); + } + + public ResourceTemplate(String uriTemplate, String name, String description, String mimeType, + Annotations annotations) { + this(uriTemplate, name, null, description, mimeType, annotations); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uriTemplate; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Map meta; + + public Builder uriTemplate(String uri) { + this.uriTemplate = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceTemplate build() { + Assert.hasText(uriTemplate, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceTemplate(uriTemplate, name, title, description, mimeType, annotations, meta); + } + + } + } + + /** + * The server's response to a resources/list request from the client. + * + * @param resources A list of resources that the server provides + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListResourcesResult( // @formatter:off + @JsonProperty("resources") List resources, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListResourcesResult(List resources, String nextCursor) { + this(resources, nextCursor, null); + } + } + + /** + * The server's response to a resources/templates/list request from the client. + * + * @param resourceTemplates A list of resource templates that the server provides + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListResourceTemplatesResult( // @formatter:off + @JsonProperty("resourceTemplates") List resourceTemplates, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListResourceTemplatesResult(List resourceTemplates, String nextCursor) { + this(resourceTemplates, nextCursor, null); + } + } + + /** + * Sent from the client to the server, to read a specific resource URI. + * + * @param uri The URI of the resource to read. The URI can use any protocol; it is up + * to the server how to interpret it + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ReadResourceRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public ReadResourceRequest(String uri) { + this(uri, null); + } + } + + /** + * The server's response to a resources/read request from the client. + * + * @param contents The contents of the resource + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ReadResourceResult( // @formatter:off + @JsonProperty("contents") List contents, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ReadResourceResult(List contents) { + this(contents, null); + } + } + + /** + * Sent from the client to request resources/updated notifications from the server + * whenever a particular resource changes. + * + * @param uri the URI of the resource to subscribe to. The URI can use any protocol; + * it is up to the server how to interpret it. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SubscribeRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public SubscribeRequest(String uri) { + this(uri, null); + } + } + + /** + * Sent from the client to request cancellation of resources/updated notifications + * from the server. This should follow a previous resources/subscribe request. + * + * @param uri The URI of the resource to unsubscribe from + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record UnsubscribeRequest( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public UnsubscribeRequest(String uri) { + this(uri, null); + } + } + + /** + * The contents of a specific resource or sub-resource. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION) + @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class), + @JsonSubTypes.Type(value = BlobResourceContents.class) }) + public sealed interface ResourceContents extends Meta permits TextResourceContents, BlobResourceContents { + + /** + * The URI of this resource. + * @return the URI of this resource. + */ + String uri(); + + /** + * The MIME type of this resource. + * @return the MIME type of this resource. + */ + String mimeType(); + + } + + /** + * Text contents of a resource. + * + * @param uri the URI of this resource. + * @param mimeType the MIME type of this resource. + * @param text the text of the resource. This must only be set if the resource can + * actually be represented as text (not binary data). + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record TextResourceContents( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("text") String text, + @JsonProperty("_meta") Map meta) implements ResourceContents { // @formatter:on + + public TextResourceContents(String uri, String mimeType, String text) { + this(uri, mimeType, text, null); + } + } + + /** + * Binary contents of a resource. + * + * @param uri the URI of this resource. + * @param mimeType the MIME type of this resource. + * @param blob a base64-encoded string representing the binary data of the resource. + * This must only be set if the resource can actually be represented as binary data + * (not text). + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record BlobResourceContents( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("blob") String blob, + @JsonProperty("_meta") Map meta) implements ResourceContents { // @formatter:on + + public BlobResourceContents(String uri, String mimeType, String blob) { + this(uri, mimeType, blob, null); + } + } + + // --------------------------- + // Prompt Interfaces + // --------------------------- + /** + * A prompt or prompt template that the server offers. + * + * @param name The name of the prompt or prompt template. + * @param title An optional title for the prompt. + * @param description An optional description of what this prompt provides. + * @param arguments A list of arguments to use for templating the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Prompt( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("arguments") List arguments, + @JsonProperty("_meta") Map meta) implements Identifier { // @formatter:on + + public Prompt(String name, String description, List arguments) { + this(name, null, description, arguments != null ? arguments : new ArrayList<>()); + } + + public Prompt(String name, String title, String description, List arguments) { + this(name, title, description, arguments != null ? arguments : new ArrayList<>(), null); + } + } + + /** + * Describes an argument that a prompt can accept. + * + * @param name The name of the argument. + * @param title An optional title for the argument, which can be used in UI + * @param description A human-readable description of the argument. + * @param required Whether this argument must be provided. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptArgument( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("required") Boolean required) implements Identifier { // @formatter:on + + public PromptArgument(String name, String description, Boolean required) { + this(name, null, description, required); + } + } + + /** + * Describes a message returned as part of a prompt. + * + * This is similar to `SamplingMessage`, but also supports the embedding of resources + * from the MCP server. + * + * @param role The sender or recipient of messages and data in a conversation. + * @param content The content of the message of type {@link Content}. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptMessage( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content) { // @formatter:on + } + + /** + * The server's response to a prompts/list request from the client. + * + * @param prompts A list of prompts that the server provides. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more prompts available. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListPromptsResult( // @formatter:off + @JsonProperty("prompts") List prompts, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListPromptsResult(List prompts, String nextCursor) { + this(prompts, nextCursor, null); + } + } + + /** + * Used by the client to get a prompt provided by the server. + * + * @param name The name of the prompt or prompt template. + * @param arguments Arguments to use for templating the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record GetPromptRequest( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public GetPromptRequest(String name, Map arguments) { + this(name, arguments, null); + } + } + + /** + * The server's response to a prompts/get request from the client. + * + * @param description An optional description for the prompt. + * @param messages A list of messages to display as part of the prompt. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record GetPromptResult( // @formatter:off + @JsonProperty("description") String description, + @JsonProperty("messages") List messages, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public GetPromptResult(String description, List messages) { + this(description, messages, null); + } + } + + // --------------------------- + // Tool Interfaces + // --------------------------- + /** + * The server's response to a tools/list request from the client. + * + * @param tools A list of tools that the server provides. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more tools available. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListToolsResult( // @formatter:off + @JsonProperty("tools") List tools, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListToolsResult(List tools, String nextCursor) { + this(tools, nextCursor, null); + } + } + + /** + * A JSON Schema object that describes the expected structure of arguments or output. + * + * @param type The type of the schema (e.g., "object") + * @param properties The properties of the schema object + * @param required List of required property names + * @param additionalProperties Whether additional properties are allowed + * @param defs Schema definitions using the newer $defs keyword + * @param definitions Schema definitions using the legacy definitions keyword + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JsonSchema( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("properties") Map properties, + @JsonProperty("required") List required, + @JsonProperty("additionalProperties") Boolean additionalProperties, + @JsonProperty("$defs") Map defs, + @JsonProperty("definitions") Map definitions) { // @formatter:on + } + + /** + * Additional properties describing a Tool to clients. + * + * NOTE: all properties in ToolAnnotations are **hints**. They are not guaranteed to + * provide a faithful description of tool behavior (including descriptive properties + * like `title`). + * + * Clients should never make tool use decisions based on ToolAnnotations received from + * untrusted servers. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ToolAnnotations( // @formatter:off + @JsonProperty("title") String title, + @JsonProperty("readOnlyHint") Boolean readOnlyHint, + @JsonProperty("destructiveHint") Boolean destructiveHint, + @JsonProperty("idempotentHint") Boolean idempotentHint, + @JsonProperty("openWorldHint") Boolean openWorldHint, + @JsonProperty("returnDirect") Boolean returnDirect) { // @formatter:on + } + + /** + * Represents a tool that the server provides. Tools enable servers to expose + * executable functionality to the system. Through these tools, you can interact with + * external systems, perform computations, and take actions in the real world. + * + * @param name A unique identifier for the tool. This name is used when calling the + * tool. + * @param title A human-readable title for the tool. + * @param description A human-readable description of what the tool does. This can be + * used by clients to improve the LLM's understanding of available tools. + * @param inputSchema A JSON Schema object that describes the expected structure of + * the arguments when calling this tool. This allows clients to validate tool + * @param outputSchema An optional JSON Schema object defining the structure of the + * tool's output returned in the structuredContent field of a CallToolResult. + * @param annotations Optional additional tool information. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Tool( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("inputSchema") JsonSchema inputSchema, + @JsonProperty("outputSchema") Map outputSchema, + @JsonProperty("annotations") ToolAnnotations annotations, + @JsonProperty("_meta") Map meta) { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private String title; + + private String description; + + private JsonSchema inputSchema; + + private Map outputSchema; + + private ToolAnnotations annotations; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(JsonSchema inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder inputSchema(McpJsonMapper jsonMapper, String inputSchema) { + this.inputSchema = parseSchema(jsonMapper, inputSchema); + return this; + } + + public Builder outputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + return this; + } + + public Builder outputSchema(McpJsonMapper jsonMapper, String outputSchema) { + this.outputSchema = schemaToMap(jsonMapper, outputSchema); + return this; + } + + public Builder annotations(ToolAnnotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Tool build() { + Assert.hasText(name, "name must not be empty"); + return new Tool(name, title, description, inputSchema, outputSchema, annotations, meta); + } + + } + } + + private static Map schemaToMap(McpJsonMapper jsonMapper, String schema) { + try { + return jsonMapper.readValue(schema, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid schema: " + schema, e); + } + } + + private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { + try { + return jsonMapper.readValue(schema, JsonSchema.class); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid schema: " + schema, e); + } + } + + /** + * Used by the client to call a tool provided by the server. + * + * @param name The name of the tool to call. This must match a tool name from + * tools/list. + * @param arguments Arguments to pass to the tool. These must conform to the tool's + * input schema. + * @param meta Optional metadata about the request. This can include additional + * information like `progressToken` + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CallToolRequest( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public CallToolRequest(McpJsonMapper jsonMapper, String name, String jsonArguments) { + this(name, parseJsonArguments(jsonMapper, jsonArguments), null); + } + + public CallToolRequest(String name, Map arguments) { + this(name, arguments, null); + } + + private static Map parseJsonArguments(McpJsonMapper jsonMapper, String jsonArguments) { + try { + return jsonMapper.readValue(jsonArguments, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private Map arguments; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder arguments(Map arguments) { + this.arguments = arguments; + return this; + } + + public Builder arguments(McpJsonMapper jsonMapper, String jsonArguments) { + this.arguments = parseJsonArguments(jsonMapper, jsonArguments); + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public CallToolRequest build() { + Assert.hasText(name, "name must not be empty"); + return new CallToolRequest(name, arguments, meta); + } + + } + } + + /** + * The server's response to a tools/call request from the client. + * + * @param content A list of content items representing the tool's output. Each item + * can be text, an image, or an embedded resource. + * @param isError If true, indicates that the tool execution failed and the content + * contains error information. If false or absent, indicates successful execution. + * @param structuredContent An optional JSON object that represents the structured + * result of the tool call. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CallToolResult( // @formatter:off + @JsonProperty("content") List content, + @JsonProperty("isError") Boolean isError, + @JsonProperty("structuredContent") Object structuredContent, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + /** + * @deprecated use the builder instead. + */ + @Deprecated + public CallToolResult(List content, Boolean isError) { + this(content, isError, (Object) null, null); + } + + /** + * @deprecated use the builder instead. + */ + @Deprecated + public CallToolResult(List content, Boolean isError, Map structuredContent) { + this(content, isError, structuredContent, null); + } + + /** + * Creates a new instance of {@link CallToolResult} with a string containing the + * tool result. + * @param content The content of the tool result. This will be mapped to a + * one-sized list with a {@link TextContent} element. + * @param isError If true, indicates that the tool execution failed and the + * content contains error information. If false or absent, indicates successful + * execution. + */ + @Deprecated + public CallToolResult(String content, Boolean isError) { + this(List.of(new TextContent(content)), isError, null); + } + + /** + * Creates a builder for {@link CallToolResult}. + * @return a new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CallToolResult}. + */ + public static class Builder { + + private List content = new ArrayList<>(); + + private Boolean isError = false; + + private Object structuredContent; + + private Map meta; + + /** + * Sets the content list for the tool result. + * @param content the content list + * @return this builder + */ + public Builder content(List content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + public Builder structuredContent(Object structuredContent) { + Assert.notNull(structuredContent, "structuredContent must not be null"); + this.structuredContent = structuredContent; + return this; + } + + public Builder structuredContent(McpJsonMapper jsonMapper, String structuredContent) { + Assert.hasText(structuredContent, "structuredContent must not be empty"); + try { + this.structuredContent = jsonMapper.readValue(structuredContent, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid structured content: " + structuredContent, e); + } + return this; + } + + /** + * Sets the text content for the tool result. + * @param textContent the text content + * @return this builder + */ + public Builder textContent(List textContent) { + Assert.notNull(textContent, "textContent must not be null"); + textContent.stream().map(TextContent::new).forEach(this.content::add); + return this; + } + + /** + * Adds a content item to the tool result. + * @param contentItem the content item to add + * @return this builder + */ + public Builder addContent(Content contentItem) { + Assert.notNull(contentItem, "contentItem must not be null"); + if (this.content == null) { + this.content = new ArrayList<>(); + } + this.content.add(contentItem); + return this; + } + + /** + * Adds a text content item to the tool result. + * @param text the text content + * @return this builder + */ + public Builder addTextContent(String text) { + Assert.notNull(text, "text must not be null"); + return addContent(new TextContent(text)); + } + + /** + * Sets whether the tool execution resulted in an error. + * @param isError true if the tool execution failed, false otherwise + * @return this builder + */ + public Builder isError(Boolean isError) { + Assert.notNull(isError, "isError must not be null"); + this.isError = isError; + return this; + } + + /** + * Sets the metadata for the tool result. + * @param meta metadata + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link CallToolResult} instance. + * @return a new CallToolResult instance + */ + public CallToolResult build() { + return new CallToolResult(content, isError, structuredContent, meta); + } + + } + + } + + // --------------------------- + // Sampling Interfaces + // --------------------------- + /** + * The server's preferences for model selection, requested of the client during + * sampling. + * + * @param hints Optional hints to use for model selection. If multiple hints are + * specified, the client MUST evaluate them in order (such that the first match is + * taken). The client SHOULD prioritize these hints over the numeric priorities, but + * MAY still use the priorities to select from ambiguous matches + * @param costPriority How much to prioritize cost when selecting a model. A value of + * 0 means cost is not important, while a value of 1 means cost is the most important + * factor + * @param speedPriority How much to prioritize sampling speed (latency) when selecting + * a model. A value of 0 means speed is not important, while a value of 1 means speed + * is the most important factor + * @param intelligencePriority How much to prioritize intelligence and capabilities + * when selecting a model. A value of 0 means intelligence is not important, while a + * value of 1 means intelligence is the most important factor + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelPreferences( // @formatter:off + @JsonProperty("hints") List hints, + @JsonProperty("costPriority") Double costPriority, + @JsonProperty("speedPriority") Double speedPriority, + @JsonProperty("intelligencePriority") Double intelligencePriority) { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List hints; + + private Double costPriority; + + private Double speedPriority; + + private Double intelligencePriority; + + public Builder hints(List hints) { + this.hints = hints; + return this; + } + + public Builder addHint(String name) { + if (this.hints == null) { + this.hints = new ArrayList<>(); + } + this.hints.add(new ModelHint(name)); + return this; + } + + public Builder costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + public Builder speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + public Builder intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + public ModelPreferences build() { + return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority); + } + + } + } + + /** + * Hints to use for model selection. + * + * @param name A hint for a model name. The client SHOULD treat this as a substring of + * a model name; for example: `claude-3-5-sonnet` should match + * `claude-3-5-sonnet-20241022`, `sonnet` should match `claude-3-5-sonnet-20241022`, + * `claude-3-sonnet-20240229`, etc., `claude` should match any Claude model. The + * client MAY also map the string to a different provider's model name or a different + * model family, as long as it fills a similar niche + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelHint(@JsonProperty("name") String name) { + public static ModelHint of(String name) { + return new ModelHint(name); + } + } + + /** + * Describes a message issued to or received from an LLM API. + * + * @param role The sender or recipient of messages and data in a conversation + * @param content The content of the message + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SamplingMessage( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content) { // @formatter:on + } + + /** + * A request from the server to sample an LLM via the client. The client has full + * discretion over which model to select. The client should also inform the user + * before beginning sampling, to allow them to inspect the request (human in the loop) + * and decide whether to approve it. + * + * @param messages The conversation messages to send to the LLM + * @param modelPreferences The server's preferences for which model to select. The + * client MAY ignore these preferences + * @param systemPrompt An optional system prompt the server wants to use for sampling. + * The client MAY modify or omit this prompt + * @param includeContext A request to include context from one or more MCP servers + * (including the caller), to be attached to the prompt. The client MAY ignore this + * request + * @param temperature Optional temperature parameter for sampling + * @param maxTokens The maximum number of tokens to sample, as requested by the + * server. The client MAY choose to sample fewer tokens than requested + * @param stopSequences Optional stop sequences for sampling + * @param metadata Optional metadata to pass through to the LLM provider. The format + * of this metadata is provider-specific + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CreateMessageRequest( // @formatter:off + @JsonProperty("messages") List messages, + @JsonProperty("modelPreferences") ModelPreferences modelPreferences, + @JsonProperty("systemPrompt") String systemPrompt, + @JsonProperty("includeContext") ContextInclusionStrategy includeContext, + @JsonProperty("temperature") Double temperature, + @JsonProperty("maxTokens") Integer maxTokens, + @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("metadata") Map metadata, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + // backwards compatibility constructor + public CreateMessageRequest(List messages, ModelPreferences modelPreferences, + String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, + List stopSequences, Map metadata) { + this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, + metadata, null); + } + + public enum ContextInclusionStrategy { + + // @formatter:off + @JsonProperty("none") NONE, + @JsonProperty("thisServer") THIS_SERVER, + @JsonProperty("allServers")ALL_SERVERS + } // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List messages; + + private ModelPreferences modelPreferences; + + private String systemPrompt; + + private ContextInclusionStrategy includeContext; + + private Double temperature; + + private Integer maxTokens; + + private List stopSequences; + + private Map metadata; + + private Map meta; + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder modelPreferences(ModelPreferences modelPreferences) { + this.modelPreferences = modelPreferences; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public Builder includeContext(ContextInclusionStrategy includeContext) { + this.includeContext = includeContext; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public CreateMessageRequest build() { + return new CreateMessageRequest(messages, modelPreferences, systemPrompt, includeContext, temperature, + maxTokens, stopSequences, metadata, meta); + } + + } + } + + /** + * The client's response to a sampling/create_message request from the server. The + * client should inform the user before returning the sampled message, to allow them + * to inspect the response (human in the loop) and decide whether to allow the server + * to see it. + * + * @param role The role of the message sender (typically assistant) + * @param content The content of the sampled message + * @param model The name of the model that generated the message + * @param stopReason The reason why sampling stopped, if known + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CreateMessageResult( // @formatter:off + @JsonProperty("role") Role role, + @JsonProperty("content") Content content, + @JsonProperty("model") String model, + @JsonProperty("stopReason") StopReason stopReason, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public enum StopReason { + + // @formatter:off + @JsonProperty("endTurn") END_TURN("endTurn"), + @JsonProperty("stopSequence") STOP_SEQUENCE("stopSequence"), + @JsonProperty("maxTokens") MAX_TOKENS("maxTokens"), + @JsonProperty("unknown") UNKNOWN("unknown"); + // @formatter:on + + private final String value; + + StopReason(String value) { + this.value = value; + } + + @JsonCreator + private static StopReason of(String value) { + return Arrays.stream(StopReason.values()) + .filter(stopReason -> stopReason.value.equals(value)) + .findFirst() + .orElse(StopReason.UNKNOWN); + } + + } + + public CreateMessageResult(Role role, Content content, String model, StopReason stopReason) { + this(role, content, model, stopReason, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Role role = Role.ASSISTANT; + + private Content content; + + private String model; + + private StopReason stopReason = StopReason.END_TURN; + + private Map meta; + + public Builder role(Role role) { + this.role = role; + return this; + } + + public Builder content(Content content) { + this.content = content; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder stopReason(StopReason stopReason) { + this.stopReason = stopReason; + return this; + } + + public Builder message(String message) { + this.content = new TextContent(message); + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public CreateMessageResult build() { + return new CreateMessageResult(role, content, model, stopReason, meta); + } + + } + } + + // Elicitation + /** + * A request from the server to elicit additional information from the user via the + * client. + * + * @param message The message to present to the user + * @param requestedSchema A restricted subset of JSON Schema. Only top-level + * properties are allowed, without nesting + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest( // @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + // backwards compatibility constructor + public ElicitRequest(String message, Map requestedSchema) { + this(message, requestedSchema, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String message; + + private Map requestedSchema; + + private Map meta; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema, meta); + } + + } + } + + /** + * The client's response to an elicitation request. + * + * @param action The user action in response to the elicitation. "accept": User + * submitted the form/confirmed the action, "decline": User explicitly declined the + * action, "cancel": User dismissed without making an explicit choice + * @param content The submitted form data, only present when action is "accept". + * Contains values matching the requested schema + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult( // @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public enum Action { + + // @formatter:off + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } // @formatter:on + + // backwards compatibility constructor + public ElicitResult(Action action, Map content) { + this(action, content, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Action action; + + private Map content; + + private Map meta; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content, meta); + } + + } + } + + // --------------------------- + // Pagination Interfaces + // --------------------------- + /** + * A request that supports pagination using cursors. + * + * @param cursor An opaque token representing the current pagination position. If + * provided, the server should return results starting after this cursor + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PaginatedRequest( // @formatter:off + @JsonProperty("cursor") String cursor, + @JsonProperty("_meta") Map meta) implements Request { // @formatter:on + + public PaginatedRequest(String cursor) { + this(cursor, null); + } + + /** + * Creates a new paginated request with an empty cursor. + */ + public PaginatedRequest() { + this(null); + } + } + + /** + * An opaque token representing the pagination position after the last returned + * result. If present, there may be more results available. + * + * @param nextCursor An opaque token representing the pagination position after the + * last returned result. If present, there may be more results available + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { + } + + // --------------------------- + // Progress and Logging + // --------------------------- + /** + * The Model Context Protocol (MCP) supports optional progress tracking for + * long-running operations through notification messages. Either side can send + * progress notifications to provide updates about operation status. + * + * @param progressToken A unique token to identify the progress notification. MUST be + * unique across all active requests. + * @param progress A value indicating the current progress. + * @param total An optional total amount of work to be done, if known. + * @param message An optional message providing additional context about the progress. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ProgressNotification( // @formatter:off + @JsonProperty("progressToken") Object progressToken, + @JsonProperty("progress") Double progress, + @JsonProperty("total") Double total, + @JsonProperty("message") String message, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + public ProgressNotification(Object progressToken, double progress, Double total, String message) { + this(progressToken, progress, total, message, null); + } + } + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to send + * resources update message to clients. + * + * @param uri The updated resource uri. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourcesUpdatedNotification(// @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + public ResourcesUpdatedNotification(String uri) { + this(uri, null); + } + } + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to send + * structured log messages to clients. Clients can control logging verbosity by + * setting minimum log levels, with servers sending notifications containing severity + * levels, optional logger names, and arbitrary JSON-serializable data. + * + * @param level The severity levels. The minimum log level is set by the client. + * @param logger The logger that generated the message. + * @param data JSON-serializable logging data. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record LoggingMessageNotification( // @formatter:off + @JsonProperty("level") LoggingLevel level, + @JsonProperty("logger") String logger, + @JsonProperty("data") String data, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + // backwards compatibility constructor + public LoggingMessageNotification(LoggingLevel level, String logger, String data) { + this(level, logger, data, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private LoggingLevel level = LoggingLevel.INFO; + + private String logger = "server"; + + private String data; + + private Map meta; + + public Builder level(LoggingLevel level) { + this.level = level; + return this; + } + + public Builder logger(String logger) { + this.logger = logger; + return this; + } + + public Builder data(String data) { + this.data = data; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public LoggingMessageNotification build() { + return new LoggingMessageNotification(level, logger, data, meta); + } + + } + } + + public enum LoggingLevel { + + // @formatter:off + @JsonProperty("debug") DEBUG(0), + @JsonProperty("info") INFO(1), + @JsonProperty("notice") NOTICE(2), + @JsonProperty("warning") WARNING(3), + @JsonProperty("error") ERROR(4), + @JsonProperty("critical") CRITICAL(5), + @JsonProperty("alert") ALERT(6), + @JsonProperty("emergency") EMERGENCY(7); + // @formatter:on + + private final int level; + + LoggingLevel(int level) { + this.level = level; + } + + public int level() { + return level; + } + + } + + /** + * A request from the client to the server, to enable or adjust logging. + * + * @param level The level of logging that the client wants to receive from the server. + * The server should send all logs at this level and higher (i.e., more severe) to the + * client as notifications/message + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { + } + + // --------------------------- + // Autocomplete + // --------------------------- + public sealed interface CompleteReference permits PromptReference, ResourceReference { + + String type(); + + String identifier(); + + } + + /** + * Identifies a prompt for completion requests. + * + * @param type The reference type identifier (typically "ref/prompt") + * @param name The name of the prompt + * @param title An optional title for the prompt + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptReference( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("name") String name, + @JsonProperty("title") String title ) implements McpSchema.CompleteReference, Identifier { // @formatter:on + + public static final String TYPE = "ref/prompt"; + + public PromptReference(String type, String name) { + this(type, name, null); + } + + public PromptReference(String name) { + this(TYPE, name, null); + } + + @Override + public String identifier() { + return name(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + PromptReference that = (PromptReference) obj; + return java.util.Objects.equals(identifier(), that.identifier()) + && java.util.Objects.equals(type(), that.type()); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(identifier(), type()); + } + } + + /** + * A reference to a resource or resource template definition for completion requests. + * + * @param type The reference type identifier (typically "ref/resource") + * @param uri The URI or URI template of the resource + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceReference( // @formatter:off + @JsonProperty("type") String type, + @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { // @formatter:on + + public static final String TYPE = "ref/resource"; + + public ResourceReference(String uri) { + this(TYPE, uri); + } + + @Override + public String identifier() { + return uri(); + } + } + + /** + * A request from the client to the server, to ask for completion options. + * + * @param ref A reference to a prompt or resource template definition + * @param argument The argument's information for completion requests + * @param meta See specification for notes on _meta usage + * @param context Additional, optional context for completions + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteRequest( // @formatter:off + @JsonProperty("ref") McpSchema.CompleteReference ref, + @JsonProperty("argument") CompleteArgument argument, + @JsonProperty("_meta") Map meta, + @JsonProperty("context") CompleteContext context) implements Request { // @formatter:on + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument, Map meta) { + this(ref, argument, meta, null); + } + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument, CompleteContext context) { + this(ref, argument, null, context); + } + + public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument) { + this(ref, argument, null, null); + } + + /** + * The argument's information for completion requests. + * + * @param name The name of the argument + * @param value The value of the argument to use for completion matching + */ + public record CompleteArgument(@JsonProperty("name") String name, @JsonProperty("value") String value) { + } + + /** + * Additional, optional context for completions. + * + * @param arguments Previously-resolved variables in a URI template or prompt + */ + public record CompleteContext(@JsonProperty("arguments") Map arguments) { + } + } + + /** + * The server's response to a completion/complete request. + * + * @param completion The completion information containing values and metadata + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteResult(// @formatter:off + @JsonProperty("completion") CompleteCompletion completion, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + // backwards compatibility constructor + public CompleteResult(CompleteCompletion completion) { + this(completion, null); + } + + /** + * The server's response to a completion/complete request + * + * @param values An array of completion values. Must not exceed 100 items + * @param total The total number of completion options available. This can exceed + * the number of values actually sent in the response + * @param hasMore Indicates whether there are additional completion options beyond + * those provided in the current response, even if the exact total is unknown + */ + @JsonInclude(JsonInclude.Include.ALWAYS) + public record CompleteCompletion( // @formatter:off + @JsonProperty("values") List values, + @JsonProperty("total") Integer total, + @JsonProperty("hasMore") Boolean hasMore) { // @formatter:on + } + } + + // --------------------------- + // Content Types + // --------------------------- + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") + @JsonSubTypes({ @JsonSubTypes.Type(value = TextContent.class, name = "text"), + @JsonSubTypes.Type(value = ImageContent.class, name = "image"), + @JsonSubTypes.Type(value = AudioContent.class, name = "audio"), + @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource"), + @JsonSubTypes.Type(value = ResourceLink.class, name = "resource_link") }) + public sealed interface Content extends Meta + permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { + + default String type() { + if (this instanceof TextContent) { + return "text"; + } + else if (this instanceof ImageContent) { + return "image"; + } + else if (this instanceof AudioContent) { + return "audio"; + } + else if (this instanceof EmbeddedResource) { + return "resource"; + } + else if (this instanceof ResourceLink) { + return "resource_link"; + } + throw new IllegalArgumentException("Unknown content type: " + this); + } + + } + + /** + * Text provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param text The text content of the message + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record TextContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("text") String text, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + public TextContent(Annotations annotations, String text) { + this(annotations, text, null); + } + + public TextContent(String content) { + this(null, content, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#TextContent(Annotations, String)} instead. + */ + @Deprecated + public TextContent(List audience, Double priority, String content) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, content, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link TextContent#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * An image provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param data The base64-encoded image data + * @param mimeType The MIME type of the image. Different providers may support + * different image types + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ImageContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("data") String data, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + public ImageContent(Annotations annotations, String data, String mimeType) { + this(annotations, data, mimeType, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#ImageContent(Annotations, String, String)} instead. + */ + @Deprecated + public ImageContent(List audience, Double priority, String data, String mimeType) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, data, mimeType, + null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link ImageContent#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * Audio provided to or from an LLM. + * + * @param annotations Optional annotations for the client + * @param data The base64-encoded audio data + * @param mimeType The MIME type of the audio. Different providers may support + * different audio types + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record AudioContent( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("data") String data, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + // backwards compatibility constructor + public AudioContent(Annotations annotations, String data, String mimeType) { + this(annotations, data, mimeType, null); + } + } + + /** + * The contents of a resource, embedded into a prompt or tool call result. + * + * It is up to the client how best to render embedded resources for the benefit of the + * LLM and/or the user. + * + * @param annotations Optional annotations for the client + * @param resource The resource contents that are embedded + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record EmbeddedResource( // @formatter:off + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("resource") ResourceContents resource, + @JsonProperty("_meta") Map meta) implements Annotated, Content { // @formatter:on + + // backwards compatibility constructor + public EmbeddedResource(Annotations annotations, ResourceContents resource) { + this(annotations, resource, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#EmbeddedResource(Annotations, ResourceContents)} + * instead. + */ + @Deprecated + public EmbeddedResource(List audience, Double priority, ResourceContents resource) { + this(audience != null || priority != null ? new Annotations(audience, priority) : null, resource, null); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#annotations()} instead. + */ + @Deprecated + public List audience() { + return annotations == null ? null : annotations.audience(); + } + + /** + * @deprecated Only exists for backwards-compatibility purposes. Use + * {@link EmbeddedResource#annotations()} instead. + */ + @Deprecated + public Double priority() { + return annotations == null ? null : annotations.priority(); + } + } + + /** + * A known resource that the server is capable of reading. + * + * @param uri the URI of the resource. + * @param name A human-readable name for this resource. This can be used by clients to + * populate UI elements. + * @param title A human-readable title for this resource. + * @param description A description of what this resource represents. This can be used + * by clients to improve the LLM's understanding of available resources. It can be + * thought of like a "hint" to the model. + * @param mimeType The MIME type of this resource, if known. + * @param size The size of the raw resource content, in bytes (i.e., before base64 + * encoding or any tokenization), if known. This can be used by Hosts to display file + * sizes and estimate context window usage. + * @param annotations Optional annotations for the client. The client can use + * annotations to inform how objects are used or displayed. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceLink( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("uri") String uri, + @JsonProperty("description") String description, + @JsonProperty("mimeType") String mimeType, + @JsonProperty("size") Long size, + @JsonProperty("annotations") Annotations annotations, + @JsonProperty("_meta") Map meta) implements Content, ResourceContent { // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private String title; + + private String uri; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Long size; + + private Map meta; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder size(Long size) { + this.size = size; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceLink build() { + Assert.hasText(uri, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceLink(name, title, uri, description, mimeType, size, annotations, meta); + } + + } + } + + // --------------------------- + // Roots + // --------------------------- + /** + * Represents a root directory or file that the server can operate on. + * + * @param uri The URI identifying the root. This *must* start with file:// for now. + * This restriction may be relaxed in future versions of the protocol to allow other + * URI schemes. + * @param name An optional name for the root. This can be used to provide a + * human-readable identifier for the root, which may be useful for display purposes or + * for referencing the root in other parts of the application. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Root( // @formatter:off + @JsonProperty("uri") String uri, + @JsonProperty("name") String name, + @JsonProperty("_meta") Map meta) { // @formatter:on + + public Root(String uri, String name) { + this(uri, name, null); + } + } + + /** + * The client's response to a roots/list request from the server. This result contains + * an array of Root objects, each representing a root directory or file that the + * server can operate on. + * + * @param roots An array of Root objects, each representing a root directory or file + * that the server can operate on. + * @param nextCursor An optional cursor for pagination. If present, indicates there + * are more roots available. The client can use this cursor to request the next page + * of results by sending a roots/list request with the cursor parameter set to this + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ListRootsResult( // @formatter:off + @JsonProperty("roots") List roots, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + + public ListRootsResult(List roots) { + this(roots, null); + } + + public ListRootsResult(List roots, String nextCursor) { + this(roots, nextCursor, null); + } + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java similarity index 64% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..241f7d8b5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import java.time.Duration; @@ -7,8 +11,13 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -16,10 +25,10 @@ import reactor.core.publisher.Sinks; /** - * Represents a Model Control Protocol (MCP) session on the server side. It manages + * Represents a Model Context Protocol (MCP) session on the server side. It manages * bidirectional JSON-RPC communication with the client. */ -public class McpServerSession implements McpSession { +public class McpServerSession implements McpLoggableSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); @@ -32,13 +41,11 @@ public class McpServerSession implements McpSession { private final AtomicLong requestCounter = new AtomicLong(0); - private final InitRequestHandler initRequestHandler; - - private final InitNotificationHandler initNotificationHandler; + private final McpInitRequestHandler initRequestHandler; - private final Map> requestHandlers; + private final Map> requestHandlers; - private final Map notificationHandlers; + private final Map notificationHandlers; private final McpServerTransport transport; @@ -56,6 +63,29 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + McpInitRequestHandler initHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -68,15 +98,18 @@ public class McpServerSession implements McpSession { * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use + * @deprecated Use + * {@link #McpServerSession(String, Duration, McpServerTransport, McpInitRequestHandler, Map, Map)} */ + @Deprecated public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; - this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } @@ -109,7 +142,18 @@ private String generateRequestId() { } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.create(sink -> { @@ -154,26 +198,38 @@ public Mono sendNotification(String method, Object params) { * @return a Mono that completes when the message is processed */ public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); // TODO: Should the error go to SSE or back as POST return? return this.transport.sendMessage(errorResponse).then(Mono.empty()); }).flatMap(this.transport::sendMessage); @@ -183,7 +239,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, transportContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -196,15 +252,17 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param transportContext * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { + new TypeRef() { }); this.state.lazySet(STATE_INITIALIZING); @@ -222,39 +280,62 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), request.params())); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + return Mono.just( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, jsonRpcError)); + }); }); } /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param transportContext * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); - return this.initNotificationHandler.handle(); + // FIXME: The session ID passed here is not the same as the one in the + // legacy SSE transport. + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), + clientInfo.get(), transportContext)); } var handler = notificationHandlers.get(notification.method()); if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); + logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(copyExchange(exchange, transportContext), notification.params())); }); } + /** + * This legacy implementation assumes an exchange is established upon the + * initialization phase see: exchangeSink.tryEmitValue(...), which creates a cached + * immutable exchange. Here, we create a new exchange and copy over everything from + * that cached exchange, and use it for a single HTTP request, with the transport + * context passed in. + */ + private McpAsyncServerExchange copyExchange(McpAsyncServerExchange exchange, McpTransportContext transportContext) { + return new McpAsyncServerExchange(exchange.sessionId(), this, exchange.getClientCapabilities(), + exchange.getClientInfo(), transportContext); + } + record MethodNotFoundError(String method, String message, Object data) { } @@ -264,17 +345,22 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { + // TODO: clear pendingResponses and emit errors? return this.transport.closeGracefully(); } @Override public void close() { + // TODO: clear pendingResponses and emit errors? this.transport.close(); } /** * Request handler for the initialization request. + * + * @deprecated Use {@link McpInitRequestHandler} */ + @Deprecated public interface InitRequestHandler { /** @@ -301,7 +387,10 @@ public interface InitNotificationHandler { /** * A handler for client-initiated notifications. + * + * @deprecated Use {@link McpNotificationHandler} */ + @Deprecated public interface NotificationHandler { /** @@ -320,7 +409,9 @@ public interface NotificationHandler { * * @param the type of the response that is expected as a result of handling the * request. + * @deprecated Use {@link McpRequestHandler} */ + @Deprecated public interface RequestHandler { /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java similarity index 78% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java index 632b8cee6..39c1644e0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java new file mode 100644 index 000000000..02028ccdf --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -0,0 +1,23 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Classic implementation of {@link McpServerTransportProviderBase} for a single outgoing + * stream in bidirectional communication (STDIO and the legacy HTTP SSE). + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java similarity index 85% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java index 5fdbd7ab6..acb1ecac6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java @@ -1,5 +1,10 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; +import java.util.List; import java.util.Map; import reactor.core.publisher.Mono; @@ -29,15 +34,7 @@ * * @author Dariusz Jędrzejczyk */ -public interface McpServerTransportProvider { - - /** - * Sets the session factory that will be used to create sessions for new clients. An - * implementation of the MCP server MUST call this method before any MCP interactions - * take place. - * @param sessionFactory the session factory to be used for initiating client sessions - */ - void setSessionFactory(McpServerSession.Factory sessionFactory); +public interface McpServerTransportProviderBase { /** * Sends a notification to all connected clients. @@ -63,4 +60,12 @@ default void close() { */ Mono closeGracefully(); + /** + * Returns the protocol version supported by this transport provider. + * @return the protocol version as a string + */ + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 473a860c2..767ed673e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -4,13 +4,11 @@ package io.modelcontextprotocol.spec; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** - * Represents a Model Control Protocol (MCP) session that handles communication between + * Represents a Model Context Protocol (MCP) session that handles communication between * clients and the server. This interface provides methods for sending requests and * notifications, as well as managing the session lifecycle. * @@ -39,7 +37,7 @@ public interface McpSession { * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received */ - Mono sendRequest(String method, Object requestParams, TypeReference typeRef); + Mono sendRequest(String method, Object requestParams, TypeRef typeRef); /** * Sends a notification to the model client or server without parameters. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java new file mode 100644 index 000000000..d1c2e5206 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.List; + +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import reactor.core.publisher.Mono; + +public interface McpStatelessServerTransport { + + void setMcpHandler(McpStatelessServerHandler mcpHandler); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java new file mode 100644 index 000000000..95f8959f5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -0,0 +1,420 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Representation of a Streamable HTTP server session that keeps track of mapping + * server-initiated requests to the client and mapping arriving responses. It also allows + * handling incoming notifications. For requests, it provides the default SSE streaming + * capability without the insight into the transport-specific details of HTTP handling. + * + * @author Dariusz Jędrzejczyk + * @author Yanming Zhou + */ +public class McpStreamableServerSession implements McpLoggableSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); + + private final ConcurrentHashMap requestIdToStream = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private final AtomicReference listeningStreamRef; + + private final MissingMcpTransportSession missingMcpTransportSession; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance of the streamable session. + * @param id session ID + * @param clientCapabilities client capabilities + * @param clientInfo client info + * @param requestTimeout timeout to use for requests + * @param requestHandlers the map of MCP request handlers keyed by method name + * @param notificationHandlers the map of MCP notification handlers keyed by method + * name + */ + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, Duration requestTimeout, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.missingMcpTransportSession = new MissingMcpTransportSession(id); + this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + this.requestTimeout = requestTimeout; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + /** + * Return the Session ID. + * @return session ID + */ + public String getId() { + return this.id; + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendRequest(method, requestParams, typeRef); + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendNotification(method, params); + }); + } + + public Mono delete() { + return this.closeGracefully().then(Mono.fromRunnable(() -> { + // TODO: review in the context of history storage + // delete history, etc. + })); + } + + /** + * Create a listening stream (the generic HTTP GET request without Last-Event-ID + * header). + * @param transport The dedicated SSE transport stream + * @return a stream representation + */ + public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); + this.listeningStreamRef.set(listeningStream); + return listeningStream; + } + + // TODO: keep track of history by keeping a map from eventId to stream and then + // iterate over the events using the lastEventId + public Flux replay(Object lastEventId) { + return Flux.empty(); + } + + /** + * Provide the SSE stream of MCP messages finalized with a Response. + * @param jsonrpcRequest the MCP request triggering the stream creation + * @param transport the SSE transport stream to send messages to + * @return Mono which completes once the processing is done + */ + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers + .get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close + // remove itself from the registry and also close the underlying transport + // (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport + .sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler + .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), + transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, + null)) + .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (e instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), McpError.aggregateExceptionMessages(e)); + + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), + null, jsonRpcError); + return Mono.just(errorResponse); + }) + .flatMap(transport::sendMessage) + .then(transport.closeGracefully()); + }); + } + + /** + * Handle the MCP notification. + * @param notification MCP notification + * @return Mono which completes upon succesful handling + */ + public Mono accept(McpSchema.JSONRPCNotification notification) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("No handler registered for notification method: {}", notification); + return Mono.empty(); + } + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); + }); + + } + + /** + * Handle the MCP response. + * @param response MCP response to the server-initiated request + * @return Mono which completes upon successful processing + */ + public Mono accept(McpSchema.JSONRPCResponse response) { + return Mono.defer(() -> { + logger.debug("Received response: {}", response); + + if (response.id() != null) { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + else { + sink.success(response); + } + } + else { + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); + } + return Mono.empty(); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + return listeningStream.closeGracefully(); + // TODO: Also close all the open streams + }); + } + + @Override + public void close() { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + if (listeningStream != null) { + listeningStream.close(); + } + // TODO: Also close all open streams + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Factory for new Streamable HTTP MCP sessions. + */ + public interface Factory { + + /** + * Given an initialize request, create a composite for the session initialization + * @param initializeRequest the initialization request from the client + * @return a composite allowing the session to start + */ + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Composite holding the {@link McpStreamableServerSession} and the initialization + * result + * + * @param session the session instance + * @param initResult the result to use to respond to the client + */ + public record McpStreamableServerSessionInit(McpStreamableServerSession session, + Mono initResult) { + } + + /** + * An individual SSE stream within a Streamable HTTP context. Can be either the + * listening GET SSE stream or a request-specific POST SSE stream. + */ + public final class McpStreamableServerSessionStream implements McpLoggableSession { + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final McpStreamableServerTransport transport; + + private final String transportId; + + private final Supplier uuidGenerator; + + /** + * Constructor accepting the dedicated transport representing the SSE stream. + * @param transport request-specific SSE transport stream + */ + public McpStreamableServerSessionStream(McpStreamableServerTransport transport) { + this.transport = transport; + this.transportId = UUID.randomUUID().toString(); + // This ID design allows for a constant-time extraction of the history by + // precisely identifying the SSE stream using the first component + this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel); + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + String requestId = McpStreamableServerSession.this.generateRequestId(); + + McpStreamableServerSession.this.requestIdToStream.put(requestId, this); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + method, requestId, requestParams); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> { + }, sink::error); + }).timeout(requestTimeout).doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, method, params); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + return this.transport.sendMessage(jsonrpcNotification, messageId); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + return this.transport.closeGracefully(); + }); + } + + @Override + public void close() { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + this.transport.close(); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java new file mode 100644 index 000000000..f53c68900 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * Streamable HTTP server transport representing an individual SSE stream. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransport extends McpServerTransport { + + /** + * Send a message to the client with a message ID for use in the SSE event payload + * @param message the JSON-RPC payload + * @param messageId message id for SSE events + * @return Mono which completes when done + */ + Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java new file mode 100644 index 000000000..09fe9fb0e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport for Streamable HTTP + * servers. Implement this interface to bridge between a particular server-side technology + * and the MCP server transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} + * or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. + * As a result of the MCP server creation, the provider will be notified of a + * {@link McpStreamableServerSession.Factory} which will be used to handle a 1:1 + * communication between a newly connected client and the server using a session concept. + * The provider's responsibility is to create instances of + * {@link McpStreamableServerTransport} that the session will utilise during the session + * lifetime. + * + *

+ * Finally, the {@link McpStreamableServerTransport}s can be closed in bulk when + * {@link #close()} or {@link #closeGracefully()} are called as part of the normal + * application shutdown event. Individual {@link McpStreamableServerTransport}s can also + * be closed on a per-session basis, where the {@link McpServerSession#close()} or + * {@link McpServerSession#closeGracefully()} closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpStreamableServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 40d9ba7ac..0a732bab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -4,8 +4,10 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import java.util.List; + import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -75,6 +77,10 @@ default void close() { * @param typeRef the type reference for the object to unmarshal * @return the unmarshalled object */ - T unmarshalFrom(Object data, TypeReference typeRef); + T unmarshalFrom(Object data, TypeRef typeRef); + + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java new file mode 100644 index 000000000..cfd3dae31 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java @@ -0,0 +1,38 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +/** + * Exception thrown when there is an issue with the transport layer of the Model Context + * Protocol (MCP). + * + *

+ * This exception is used to indicate errors that occur during communication between the + * MCP client and server, such as connection failures, protocol violations, or unexpected + * responses. + * + * @author Christian Tzolov + */ +public class McpTransportException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + public McpTransportException(String message) { + super(message); + } + + public McpTransportException(String message, Throwable cause) { + super(message, cause); + } + + public McpTransportException(Throwable cause) { + super(cause); + } + + public McpTransportException(String message, Throwable cause, boolean enableSuppression, + boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java similarity index 96% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index 555f018f8..68f0fc5bb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -1,9 +1,13 @@ -package io.modelcontextprotocol.spec; +/* + * Copyright 2024-2025 the original author or authors. + */ -import org.reactivestreams.Publisher; +package io.modelcontextprotocol.spec; import java.util.Optional; +import org.reactivestreams.Publisher; + /** * An abstraction of the session as perceived from the MCP transport layer. Not to be * confused with the {@link McpSession} type that operates at the level of the JSON-RPC diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java new file mode 100644 index 000000000..60e2850b9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java @@ -0,0 +1,23 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.util.annotation.Nullable; + +/** + * Exception thrown when trying to use an {@link McpTransportSession} that has been + * closed. + * + * @see ClosedMcpTransportSession + * @author Daniel Garnier-Moiroux + */ +public class McpTransportSessionClosedException extends RuntimeException { + + public McpTransportSessionClosedException(@Nullable String sessionId) { + super(sessionId != null ? "MCP session with ID %s has been closed".formatted(sessionId) + : "MCP session has been closed"); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java index 474a18ae0..eced49ec3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java similarity index 96% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java index 2d6dcce75..322afda63 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java new file mode 100644 index 000000000..0bf70d5b8 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * A {@link McpLoggableSession} which represents a missing stream that would allow the + * server to communicate with the client. Specifically, it can be used when a Streamable + * HTTP client has not opened a listening SSE stream to accept messages for interactions + * unrelated with concurrently running client-initiated requests. + * + * @author Dariusz Jędrzejczyk + */ +public class MissingMcpTransportSession implements McpLoggableSession { + + private final String sessionId; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance with the Session ID specified. + * @param sessionId session ID + */ + public MissingMcpTransportSession(String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java new file mode 100644 index 000000000..d3d34db62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.spec; + +public interface ProtocolVersions { + + /** + * MCP protocol version for 2024-11-05. + * https://modelcontextprotocol.io/specification/2024-11-05 + */ + String MCP_2024_11_05 = "2024-11-05"; + + /** + * MCP protocol version for 2025-03-26. + * https://modelcontextprotocol.io/specification/2025-03-26 + */ + String MCP_2025_03_26 = "2025-03-26"; + + /** + * MCP protocol version for 2025-06-18. + * https://modelcontextprotocol.io/specification/2025-06-18 + */ + String MCP_2025_06_18 = "2025-06-18"; + + /** + * MCP protocol version for 2025-11-25. + * https://modelcontextprotocol.io/specification/2025-11-25 + */ + String MCP_2025_11_25 = "2025-11-25"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java similarity index 84% rename from mcp/src/main/java/io/modelcontextprotocol/util/Assert.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java index d68188c6f..1fa6b3058 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java @@ -76,4 +76,17 @@ public static boolean hasText(@Nullable String str) { return (str != null && !str.isBlank()); } + /** + * Assert a boolean expression, throwing an {@code IllegalArgumentException} if the + * expression evaluates to {@code false}. + * @param expression a boolean expression + * @param message the exception message to use if the assertion fails + * @throws IllegalArgumentException if {@code expression} is {@code false} + */ + public static void isTrue(boolean expression, String message) { + if (!expression) { + throw new IllegalArgumentException(message); + } + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java index b2e9a5285..c3b922edf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -33,9 +33,7 @@ public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { * @param uriTemplate The URI template to be used for variable extraction */ public DefaultMcpUriTemplateManager(String uriTemplate) { - if (uriTemplate == null || uriTemplate.isEmpty()) { - throw new IllegalArgumentException("URI template must not be null or empty"); - } + Assert.hasText(uriTemplate, "URI template must not be null or empty"); this.uriTemplate = uriTemplate; } @@ -48,10 +46,6 @@ public DefaultMcpUriTemplateManager(String uriTemplate) { */ @Override public List getVariableNames() { - if (uriTemplate == null || uriTemplate.isEmpty()) { - return List.of(); - } - List variables = new ArrayList<>(); Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); @@ -81,7 +75,7 @@ public Map extractVariableValues(String requestUri) { Map variableValues = new HashMap<>(); List uriVariables = this.getVariableNames(); - if (requestUri == null || uriVariables.isEmpty()) { + if (!Utils.hasText(requestUri) || uriVariables.isEmpty()) { return variableValues; } @@ -147,12 +141,30 @@ public boolean matches(String uri) { return uri.equals(this.uriTemplate); } - // Convert the pattern to a regex - String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); - regex = regex.replace("/", "\\/"); + // Convert the URI template into a robust regex pattern that escapes special + // characters like '?'. + StringBuilder patternBuilder = new StringBuilder("^"); + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Append the literal part of the template, safely quoted + String textBefore = this.uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + // Append a capturing group for the variable itself + patternBuilder.append("([^/]+?)"); + lastEnd = variableMatcher.end(); + } + + // Append any remaining literal text after the last variable + if (lastEnd < this.uriTemplate.length()) { + patternBuilder.append(Pattern.quote(this.uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); // Check if the URI matches the regex - return Pattern.compile(regex).matcher(uri).matches(); + return Pattern.compile(patternBuilder.toString()).matcher(uri).matches(); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java similarity index 86% rename from mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java index 3870b76fc..fd1a3bd71 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java @@ -1,12 +1,13 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.util; /** * @author Christian Tzolov */ -public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { +public class DefaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { /** * Creates a new instance of {@link McpUriTemplateManager} with the specified URI diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java new file mode 100644 index 000000000..6d53ed516 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -0,0 +1,217 @@ +/** + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * A utility class for scheduling regular keep-alive calls to maintain connections. It + * sends periodic keep-alive, ping, messages to connected mcp clients to prevent idle + * timeouts. + * + * The pings are sent to all active mcp sessions at regular intervals. + * + * @author Christian Tzolov + */ +public class KeepAliveScheduler { + + private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); + + private static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { + }; + + /** Initial delay before the first keepAlive call */ + private final Duration initialDelay; + + /** Interval between subsequent keepAlive calls */ + private final Duration interval; + + /** The scheduler used for executing keepAlive calls */ + private final Scheduler scheduler; + + /** The current state of the scheduler */ + private final AtomicBoolean isRunning = new AtomicBoolean(false); + + /** The current subscription for the keepAlive calls */ + private Disposable currentSubscription; + + // TODO Currently we do not support the streams (streamable http session created by + // http post/get) + + /** Supplier for reactive McpSession instances */ + private final Supplier> mcpSessions; + + /** + * Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a + * supplier for McpSession instances. + * @param scheduler The scheduler to use for executing keepAlive calls + * @param initialDelay Initial delay before the first keepAlive call + * @param interval Interval between subsequent keepAlive calls + * @param mcpSessions Supplier for McpSession instances + */ + KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval, + Supplier> mcpSessions) { + this.scheduler = scheduler; + this.initialDelay = initialDelay; + this.interval = interval; + this.mcpSessions = mcpSessions; + } + + /** + * Creates a new Builder instance for constructing KeepAliveScheduler. + * @return A new Builder instance + */ + public static Builder builder(Supplier> mcpSessions) { + return new Builder(mcpSessions); + } + + /** + * Starts regular keepAlive calls with sessions supplier. + * @return Disposable to control the scheduled execution + */ + public Disposable start() { + if (this.isRunning.compareAndSet(false, true)) { + + this.currentSubscription = Flux.interval(this.initialDelay, this.interval, this.scheduler) + .doOnNext(tick -> { + this.mcpSessions.get() + .flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF) + .doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session, + e.getMessage())) + .onErrorComplete()) + .subscribe(); + }) + .doOnCancel(() -> this.isRunning.set(false)) + .doOnComplete(() -> this.isRunning.set(false)) + .onErrorComplete(error -> { + logger.error("KeepAlive scheduler error", error); + this.isRunning.set(false); + return true; + }) + .subscribe(); + + return this.currentSubscription; + } + else { + throw new IllegalStateException("KeepAlive scheduler is already running. Stop it first."); + } + } + + /** + * Stops the currently running keepAlive scheduler. + */ + public void stop() { + if (this.currentSubscription != null && !this.currentSubscription.isDisposed()) { + this.currentSubscription.dispose(); + } + this.isRunning.set(false); + } + + /** + * Checks if the scheduler is currently running. + * @return true if running, false otherwise + */ + public boolean isRunning() { + return this.isRunning.get(); + } + + /** + * Shuts down the scheduler and releases resources. + */ + public void shutdown() { + stop(); + if (this.scheduler instanceof Disposable) { + ((Disposable) this.scheduler).dispose(); + } + } + + /** + * Builder class for creating KeepAliveScheduler instances with fluent API. + */ + public static class Builder { + + private Scheduler scheduler = Schedulers.boundedElastic(); + + private Duration initialDelay = Duration.ofSeconds(0); + + private Duration interval = Duration.ofSeconds(30); + + private Supplier> mcpSessions; + + /** + * Creates a new Builder instance with a supplier for McpSession instances. + * @param mcpSessions The supplier for McpSession instances + */ + Builder(Supplier> mcpSessions) { + Assert.notNull(mcpSessions, "McpSessions supplier must not be null"); + this.mcpSessions = mcpSessions; + } + + /** + * Sets the scheduler to use for executing keepAlive calls. + * @param scheduler The scheduler to use: + *
    + *
  • Schedulers.single() - single-threaded scheduler
  • + *
  • Schedulers.boundedElastic() - bounded elastic scheduler for I/O operations + * (Default)
  • + *
  • Schedulers.parallel() - parallel scheduler for CPU-intensive + * operations
  • + *
  • Schedulers.immediate() - immediate scheduler for synchronous execution
  • + *
+ * @return This builder instance for method chaining + */ + public Builder scheduler(Scheduler scheduler) { + Assert.notNull(scheduler, "Scheduler must not be null"); + this.scheduler = scheduler; + return this; + } + + /** + * Sets the initial delay before the first keepAlive call. + * @param initialDelay The initial delay duration + * @return This builder instance for method chaining + */ + public Builder initialDelay(Duration initialDelay) { + Assert.notNull(initialDelay, "Initial delay must not be null"); + this.initialDelay = initialDelay; + return this; + } + + /** + * Sets the interval between subsequent keepAlive calls. + * @param interval The interval duration + * @return This builder instance for method chaining + */ + public Builder interval(Duration interval) { + Assert.notNull(interval, "Interval must not be null"); + this.interval = interval; + return this; + } + + /** + * Builds and returns a new KeepAliveScheduler instance. + * @return A new KeepAliveScheduler configured with the builder's settings + */ + public KeepAliveScheduler build() { + return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java similarity index 99% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java index 9644f9a6c..389727b45 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.util; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java similarity index 98% rename from mcp/src/main/java/io/modelcontextprotocol/util/Utils.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java index 8e654e596..cd420100c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,12 +4,12 @@ package io.modelcontextprotocol.util; -import reactor.util.annotation.Nullable; - import java.net.URI; import java.util.Collection; import java.util.Map; +import reactor.util.annotation.Nullable; + /** * Miscellaneous utility methods. * @@ -69,6 +69,9 @@ public static boolean isEmpty(@Nullable Map map) { * base URL or URI is malformed */ public static URI resolveUri(URI baseUrl, String endpointUrl) { + if (!Utils.hasText(endpointUrl)) { + return baseUrl; + } URI endpointUri = URI.create(endpointUrl); if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java similarity index 86% rename from mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java index 6f041daa6..8f68f0d6e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -12,7 +12,7 @@ import java.util.List; import java.util.Map; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManager; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import org.junit.jupiter.api.BeforeEach; @@ -29,7 +29,7 @@ public class McpUriTemplateManagerTests { @BeforeEach void setUp() { - this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + this.uriTemplateFactory = new DefaultMcpUriTemplateManagerFactory(); } @Test @@ -94,4 +94,13 @@ void shouldMatchUriAgainstTemplatePattern() { assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); } + @Test + void shouldMatchUriWithQueryParameters() { + String templateWithQuery = "file://name/search?={search}"; + var uriTemplateManager = this.uriTemplateFactory.create(templateWithQuery); + + assertTrue(uriTemplateManager.matches("file://name/search?=abcd"), + "Should correctly match a URI containing query parameters."); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index 482d0aac6..9854de210 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -9,8 +9,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -29,6 +29,8 @@ public class MockMcpClientTransport implements McpClientTransport { private final BiConsumer interceptor; + private String protocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + public MockMcpClientTransport() { this((t, msg) -> { }); @@ -38,6 +40,15 @@ public MockMcpClientTransport(BiConsumer protocolVersions() { + return List.of(protocolVersion); + } + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { if (inbound.tryEmitNext(message).isFailure()) { throw new RuntimeException("Failed to process incoming message " + message); @@ -88,8 +99,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java index 4be680e11..f3d6b77a7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java @@ -8,8 +8,8 @@ import java.util.List; import java.util.function.BiConsumer; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; @@ -53,14 +53,22 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; } + public void clearSentMessages() { + sent.clear(); + } + + public List getAllSentMessages() { + return new ArrayList<>(sent); + } + @Override public Mono closeGracefully() { return Mono.empty(); } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java similarity index 64% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf5..e955be89f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -1,21 +1,8 @@ /* -* Copyright 2025 - 2025 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package io.modelcontextprotocol; + * Copyright 2025-2025 the original author or authors. + */ -import java.util.Map; +package io.modelcontextprotocol; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..18a5cb999 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,231 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 9be6e553c..5b7877971 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -12,16 +13,16 @@ import java.time.Duration; import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -49,6 +50,7 @@ import io.modelcontextprotocol.spec.McpTransport; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; /** @@ -65,18 +67,12 @@ public abstract class AbstractMcpAsyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); + return Duration.ofSeconds(20); } McpAsyncClient client(McpClientTransport transport) { @@ -115,16 +111,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, String action) { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -183,6 +169,25 @@ void testListAllTools() { }); } + @Test + void testListAllToolsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testPingWithoutInitialization() { verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); @@ -334,6 +339,21 @@ void testListAllResources() { }); } + @Test + void testListAllResourcesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(result -> { + assertThat(result.resources()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy( + () -> result.resources().add(Resource.builder().uri("test://uri").name("test").build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testMcpAsyncClientState() { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -385,6 +405,20 @@ void testListAllPrompts() { }); } + @Test + void testListAllPromptsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(result -> { + assertThat(result.prompts()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); @@ -441,7 +475,8 @@ void testAddRoot() { void testAddRootWithNullValue() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) .verify(); }); } @@ -460,7 +495,7 @@ void testRemoveRoot() { void testRemoveNonExistentRoot() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) .verify(); }); @@ -468,57 +503,64 @@ void testRemoveNonExistentRoot() { @Test void testReadResource() { + AtomicInteger resourceCount = new AtomicInteger(); withClient(createMcpTransport(), client -> { Flux resources = client.initialize() .then(client.listResources(null)) - .flatMapMany(r -> Flux.fromIterable(r.resources())) + .flatMapMany(r -> { + List l = r.resources(); + resourceCount.set(l.size()); + return Flux.fromIterable(l); + }) .flatMap(r -> client.readResource(r)); - StepVerifier.create(resources).recordWith(ArrayList::new).consumeRecordedWith(readResourceResults -> { - - for (ReadResourceResult result : readResourceResults) { - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull().isNotEmpty(); - - // Validate each content item - for (ResourceContents content : result.contents()) { - assertThat(content).isNotNull(); - assertThat(content.uri()).isNotNull().isNotEmpty(); - assertThat(content.mimeType()).isNotNull().isNotEmpty(); - - // Validate content based on its type with more comprehensive - // checks - switch (content.mimeType()) { - case "text/plain" -> { - TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, - content); - assertThat(textContent.text()).isNotNull().isNotEmpty(); - assertThat(textContent.uri()).isNotEmpty(); - } - case "application/octet-stream" -> { - BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, - content); - assertThat(blobContent.blob()).isNotNull().isNotEmpty(); - assertThat(blobContent.uri()).isNotNull().isNotEmpty(); - // Validate base64 encoding format - assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); - } - default -> { - - // Still validate basic properties - if (content instanceof TextResourceContents textContent) { - assertThat(textContent.text()).isNotNull(); + StepVerifier.create(resources) + .recordWith(ArrayList::new) + .thenConsumeWhile(res -> true) + .consumeRecordedWith(readResourceResults -> { + assertThat(readResourceResults.size()).isEqualTo(resourceCount.get()); + for (ReadResourceResult result : readResourceResults) { + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, + content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + assertThat(textContent.uri()).isNotEmpty(); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, + content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + assertThat(blobContent.uri()).isNotNull().isNotEmpty(); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); } - else if (content instanceof BlobResourceContents blobContent) { - assertThat(blobContent.blob()).isNotNull(); + default -> { + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } } } } } - } - }) - .expectNextCount(10) // Expect 10 elements + }) .verifyComplete(); }); } @@ -554,6 +596,21 @@ void testListAllResourceTemplates() { }); } + @Test + void testListAllResourceTemplatesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result.resourceTemplates()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.resourceTemplates() + .add(new McpSchema.ResourceTemplate("test://template", "test", "test", null, null, null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + // @Test void testResourceSubscription() { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -623,7 +680,7 @@ void testInitializeWithElicitationCapability() { @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) + .experimental(Map.of("feature", Map.of("featureFlag", true))) .roots(true) .sampling() .build(); @@ -643,7 +700,6 @@ void testInitializeWithAllCapabilities() { assertThat(result.capabilities()).isNotNull(); }).verifyComplete()); } - // --------------------------------------- // Logging Tests // --------------------------------------- @@ -723,7 +779,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback @@ -735,4 +791,39 @@ void testSampling() { }); } + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + List receivedNotifications = new CopyOnWriteArrayList<>(); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + receivedNotifications.add(notification); + sink.tryEmitNext(notification); + return Mono.empty(); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + StepVerifier.create(client.callTool(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + }).verifyComplete(); + + // Use StepVerifier to verify the progress notifications via the sink + StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3)); + + assertThat(receivedNotifications).hasSize(2); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java similarity index 92% rename from mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 6cb694678..c67fa86bb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -13,14 +13,15 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -64,12 +65,6 @@ public abstract class AbstractMcpSyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -112,17 +107,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { @@ -552,11 +536,13 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); withClient(createMcpTransport(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), client -> { assertThatCode(() -> { @@ -639,7 +625,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback @@ -649,4 +635,48 @@ void testSampling() { }); } + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + AtomicInteger progressNotificationCount = new AtomicInteger(0); + List receivedNotifications = new CopyOnWriteArrayList<>(); + CountDownLatch latch = new CountDownLatch(2); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + System.out.println("Received progress notification: " + notification); + receivedNotifications.add(notification); + progressNotificationCount.incrementAndGet(); + latch.countDown(); + }), client -> { + client.initialize(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + CallToolResult result = client.callTool(request); + + assertThat(result).isNotNull(); + + try { + // Wait for progress notifications to be processed + latch.await(3, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertThat(progressNotificationCount.get()).isEqualTo(2); + + assertThat(receivedNotifications).isNotEmpty(); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..945278154 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import reactor.test.StepVerifier; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Test + void testPingWithExactExceptionType() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..a29ca16db --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..ee5e5de05 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(URI.create(host + "/mcp")), + eq("{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}"), + eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java new file mode 100644 index 000000000..e2037f415 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static org.assertj.core.api.Assertions.assertThatCode; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.test.StepVerifier; + +@Timeout(20) +public class HttpSseMcpAsyncClientLostConnectionTests { + + private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + McpAsyncClient client(McpClientTransport transport) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(Duration.ofSeconds(14)) + .initializationTimeout(Duration.ofSeconds(2)) + .capabilities(McpSchema.ClientCapabilities.builder().build()); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + var client = client(transport); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPingWithExactExceptionType() { + withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + // Veryfiy that the exception type is IOException and not TimeoutException + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 72% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 1b66a98cd..91a8b6c82 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,12 +4,15 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * @@ -18,12 +21,11 @@ @Timeout(15) class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - String host = "http://localhost:3004"; + private static String host = "http://localhost:3004"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -33,15 +35,15 @@ protected McpClientTransport createMcpTransport() { return HttpClientSseClientTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java new file mode 100644 index 000000000..d903b3b3c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3003"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientSseClientTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(URI.create(host + "/sse")), + isNull(), eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java new file mode 100644 index 000000000..6f7390f19 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer} postInitializationHook functionality. + * + * @author Christian Tzolov + */ +class LifecycleInitializerPostInitializationHookTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void shouldInvokePostInitializationHook() { + AtomicReference capturedInit = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + capturedInit.set(invocation.getArgument(0)); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + + // Verify the hook received correct initialization data + assertThat(capturedInit.get()).isNotNull(); + assertThat(capturedInit.get().mcpSession()).isEqualTo(mockClientSession); + assertThat(capturedInit.get().initializeResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnce() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse initialization and NOT call hook again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Hook should only be called once + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnceWithConcurrentRequests() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // Start multiple concurrent initializations + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + // TODO: can we assume the order of results? + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Hook should only be called once despite concurrent requests + assertThat(hookInvocationCount.get()).isEqualTo(1); + } + + @Test + void shouldFailInitializationWhenPostInitializationHookFails() { + RuntimeException hookError = new RuntimeException("Post-initialization hook failed"); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.error(hookError)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectErrorMatches(ex -> ex instanceof RuntimeException && ex.getCause() == hookError) + .verify(); + + // Verify initialization was not completed + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // Verify the hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenInitializationFails() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.error(new RuntimeException("Initialization failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when initialization fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenNotificationFails() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when notification fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookAgainAfterReinitialization() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + assertThat(hookInvocationCount.get()).isEqualTo(1); + + // Simulate transport session exception to trigger re-initialization + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Hook should be called twice (once for each initialization) + assertThat(hookInvocationCount.get()).isEqualTo(2); + } + + @Test + void shouldAllowPostInitializationHookToPerformAsyncOperations() { + AtomicInteger operationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))) + .thenReturn(Mono.fromRunnable(() -> operationCount.incrementAndGet()).then()); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the async operation was executed + assertThat(operationCount.get()).isEqualTo(1); + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldProvideCorrectInitializationDataToHook() { + AtomicReference capturedSession = new AtomicReference<>(); + AtomicReference capturedResult = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + Initialization init = invocation.getArgument(0); + capturedSession.set(init.mcpSession()); + capturedResult.set(init.initializeResult()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook received the correct session and result + assertThat(capturedSession.get()).isEqualTo(mockClientSession); + assertThat(capturedResult.get()).isEqualTo(MOCK_INIT_RESULT); + assertThat(capturedResult.get().protocolVersion()).isEqualTo("2.0.0"); + assertThat(capturedResult.get().serverInfo().name()).isEqualTo("test-server"); + } + + @Test + void shouldInvokePostInitializationHookAfterSuccessfulInitialization() { + AtomicReference notificationSent = new AtomicReference<>(false); + AtomicReference hookCalledAfterNotification = new AtomicReference<>(false); + + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenAnswer(invocation -> { + notificationSent.set(true); + return Mono.empty(); + }); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + // Due to flatMap chaining in doInitialize, if the hook is called, + // the notification must have been sent first + hookCalledAfterNotification.set(notificationSent.get()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook was called and notification was already sent at that point + assertThat(hookCalledAfterNotification.get()).isTrue(); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + verify(mockPostInitializationHook).apply(any(Initialization.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java new file mode 100644 index 000000000..787ee9480 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -0,0 +1,427 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer}. + */ +class LifecycleInitializerTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void constructorShouldValidateParameters() { + assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT, + mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client capabilities must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client info must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(), + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null, + mockSessionSupplier, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Initialization timeout must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, null, mockPostInitializationHook)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Session supplier must not be null"); + } + + @Test + void shouldInitializeSuccessfully() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(result).isEqualTo(MOCK_INIT_RESULT); + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); + }) + .verifyComplete(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class), + any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldUseLatestProtocolVersionInInitializeRequest() { + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(MOCK_INIT_RESULT); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest + // version + assertThat(capturedRequest.get().capabilities()).isEqualTo(CLIENT_CAPABILITIES); + assertThat(capturedRequest.get().clientInfo()).isEqualTo(CLIENT_INFO); + }) + .verifyComplete(); + } + + @Test + void shouldFailForUnsupportedProtocolVersion() { + McpSchema.InitializeResult unsupportedResult = new McpSchema.InitializeResult("999.0.0", // Unsupported + // version + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(unsupportedResult)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + } + + @Test + void shouldTimeoutOnSlowInitialization() { + VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + + Duration INITIALIZE_TIMEOUT = Duration.ofSeconds(1); + Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5); + + LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, + PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); + + StepVerifier + .withVirtualTime(() -> shortTimeoutInitializer.withInitialization("test", + init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) + .expectSubscription() + .expectNoEvent(INITIALIZE_TIMEOUT) + .expectError(RuntimeException.class) + .verify(); + } + + @Test + void shouldReuseExistingInitialization() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse the same initialization + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify session was created only once + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleConcurrentInitializationRequests() { + AtomicInteger sessionCreationCount = new AtomicInteger(0); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + sessionCreationCount.incrementAndGet(); + return mockClientSession; + }); + + // Start multiple concurrent initializations using subscribeOn with parallel + // scheduler + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Should only create one session despite concurrent requests + assertThat(sessionCreationCount.get()).isEqualTo(1); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleInitializationFailure() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + // fail once + .thenReturn(Mono.error(new RuntimeException("Connection failed"))) + // succeeds on the second call + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // The initializer can recover from previous errors + StepVerifier + .create(initializer.withInitialization("successful init", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldHandleTransportSessionNotFoundException() { + // successful initialization first + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate transport session not found + initializer.handleException(new McpTransportSessionNotFoundException("Session not found")); + + assertThat(initializer.isInitialized()).isTrue(); + + // Verify that the session was closed and re-initialized + verify(mockClientSession).close(); + + // Verify session was created 2 times (once for initial and once for + // re-initialization) + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + } + + @Test + void shouldHandleOtherExceptions() { + // Simulate a successful initialization first + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate other exception (should not trigger re-initialization) + initializer.handleException(new RuntimeException("Some other error")); + + // Should still be initialized + assertThat(initializer.isInitialized()).isTrue(); + verify(mockClientSession, never()).close(); + // Verify that the session was not re-created + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + } + + @Test + void shouldCloseGracefully() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession).closeGracefully(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldCloseImmediately() { + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Close immediately + initializer.close(); + + verify(mockClientSession).close(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldHandleCloseWithoutInitialization() { + // Close without initialization should not throw + initializer.close(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession, never()).close(); + verify(mockClientSession, never()).closeGracefully(); + } + + @Test + void shouldSetProtocolVersionsForTesting() { + List newVersions = List.of("3.0.0", "4.0.0"); + initializer.setProtocolVersions(newVersions); + + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + // Latest from new versions + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0"); + }) + .verifyComplete(); + } + + @Test + void shouldPassContextToSessionSupplier() { + String contextKey = "test.key"; + String contextValue = "test.value"; + + AtomicReference capturedContext = new AtomicReference<>(); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + capturedContext.set(invocation.getArgument(0)); + return mockClientSession; + }); + + StepVerifier + .create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult())) + .contextWrite(Context.of(contextKey, contextValue))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(capturedContext.get().hasKey(contextKey)).isTrue(); + assertThat((String) capturedContext.get().get(contextKey)).isEqualTo(contextValue); + } + + @Test + void shouldProvideAccessToMcpSessionAndInitializeResult() { + StepVerifier.create(initializer.withInitialization("test", init -> { + assertThat(init.mcpSession()).isEqualTo(mockClientSession); + assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT); + return Mono.just("success"); + })).expectNext("success").verifyComplete(); + } + + @Test + void shouldHandleNotificationFailure() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldReturnNullWhenNotInitialized() { + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + } + + @Test + void shouldReinitializeAfterTransportSessionException() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Simulate transport session exception + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Should be able to initialize again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify two separate initializations occurred + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java similarity index 91% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index a79bdf6c9..612a65898 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -4,16 +4,14 @@ package io.modelcontextprotocol.client; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.MockMcpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; @@ -25,6 +23,7 @@ import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -52,7 +51,7 @@ private static MockMcpClientTransport initializationEnabledTransport( r.id(), mockInitResult, null); t.simulateIncomingMessage(initResponse); } - }); + }).withProtocolVersion(McpSchema.LATEST_PROTOCOL_VERSION); } @Test @@ -79,8 +78,9 @@ void testSuccessfulInitialization() { // Verify initialization result assertThat(result).isNotNull(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); assertThat(result.capabilities()).isEqualTo(serverCapabilities); + assertThat(result.capabilities().logging()).isNull(); assertThat(result.serverInfo()).isEqualTo(serverInfo); assertThat(result.instructions()).isEqualTo("Test instructions"); @@ -93,7 +93,7 @@ void testSuccessfulInitialization() { } @Test - void testToolsChangeNotificationHandling() throws JsonProcessingException { + void testToolsChangeNotificationHandling() throws IOException { MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification @@ -110,8 +110,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { // Create a mock tools list that the server will return Map inputSchema = Map.of("type", "object", "properties", Map.of(), "required", List.of()); - McpSchema.Tool mockTool = new McpSchema.Tool("test-tool-1", "Test Tool 1 Description", - new ObjectMapper().writeValueAsString(inputSchema)); + McpSchema.Tool mockTool = McpSchema.Tool.builder() + .name("test-tool-1") + .description("Test Tool 1 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 1 response with nextPageToken String nextPageToken = "page2Token"; @@ -131,9 +134,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { transport.simulateIncomingMessage(toolsListResponse1); // Create mock tools for page 2 - McpSchema.Tool mockTool2 = new McpSchema.Tool("test-tool-2", "Test Tool 2 Description", - new ObjectMapper().writeValueAsString(inputSchema)); - + McpSchema.Tool mockTool2 = McpSchema.Tool.builder() + .name("test-tool-2") + .description("Test Tool 2 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 2 response with no nextPageToken (last page) McpSchema.ListToolsResult mockToolsResult2 = new McpSchema.ListToolsResult(List.of(mockTool2), null); @@ -251,8 +256,8 @@ void testPromptsChangeNotificationHandling() { assertThat(asyncMcpClient.initialize().block()).isNotNull(); // Create a mock prompts list that the server will return - McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description", - List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", "Test argument", true))); McpSchema.ListPromptsResult mockPromptsResult = new McpSchema.ListPromptsResult(List.of(mockPrompt), null); // Simulate server sending prompts/list_changed notification @@ -321,7 +326,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), - new TypeReference() { + new TypeRef() { }); assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); @@ -373,7 +378,7 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { // Create client with sampling capability but null handler assertThatThrownBy( () -> McpClient.async(transport).capabilities(ClientCapabilities.builder().sampling().build()).build()) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } @@ -425,7 +430,7 @@ void testElicitationCreateRequestHandling() { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); @@ -470,7 +475,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(action); @@ -521,8 +526,34 @@ void testElicitationCreateRequestHandlingWithNullHandler() { // Create client with elicitation capability but null handler assertThatThrownBy(() -> McpClient.async(transport) .capabilities(ClientCapabilities.builder().elicitation().build()) - .build()).isInstanceOf(McpError.class) + .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); } + @Test + void testPingMessageRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Simulate incoming ping request from server + McpSchema.JSONRPCRequest pingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_PING, "ping-id", null); + transport.simulateIncomingMessage(pingRequest); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("ping-id"); + assertThat(response.error()).isNull(); + assertThat(response.result()).isInstanceOf(Map.class); + assertThat(((Map) response.result())).isEmpty(); + + asyncMcpClient.closeGracefully(); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java new file mode 100644 index 000000000..48bf1da5b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -0,0 +1,310 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +class McpAsyncClientTests { + + public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", + "1.0.0"); + + public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .build(); + + public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( + ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + + private static final String CONTEXT_KEY = "context.key"; + + private McpClientTransport createMockTransportForToolValidation(boolean hasOutputSchema, boolean invalidOutput) { + + // Create tool with or without output schema + Map inputSchemaMap = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + + McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", inputSchemaMap, null, null, null, null); + McpSchema.Tool.Builder toolBuilder = McpSchema.Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(inputSchema); + + if (hasOutputSchema) { + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + toolBuilder.outputSchema(outputSchema); + } + + McpSchema.Tool calculatorTool = toolBuilder.build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result - valid or invalid based on parameter + Map structuredContent = invalidOutput ? Map.of("result", "5", "operation", "add") + : Map.of("result", 5, "operation", "add"); + + McpSchema.CallToolResult mockCallToolResult = McpSchema.CallToolResult.builder() + .addTextContent("Calculation result") + .structuredContent(structuredContent) + .build(); + + return new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + this.handler = handler; + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else if (McpSchema.METHOD_TOOLS_CALL.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + mockCallToolResult, null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + } + + @Test + void validateContextPassedToTransportConnect() { + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + final AtomicReference contextValue = new AtomicReference<>(); + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + if (ctx.hasKey(CONTEXT_KEY)) { + this.contextValue.set(ctx.get(CONTEXT_KEY)); + } + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!"hello".equals(this.contextValue.get())) { + return Mono.error(new RuntimeException("Context value not propagated via #connect method")); + } + // We're only interested in handling the init request to provide an init + // response + if (!(message instanceof McpSchema.JSONRPCRequest)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); + return handler.apply(Mono.just(initResponse)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + assertThatCode(() -> { + McpAsyncClient client = McpClient.async(transport).build(); + client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); + }).doesNotThrowAnyException(); + } + + @Test + void testCallToolWithOutputSchemaValidationSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(true, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithNoOutputSchemaSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(false, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithOutputSchemaValidationFailure() { + McpClientTransport transport = createMockTransportForToolValidation(true, true); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectErrorMatches(ex -> ex instanceof IllegalArgumentException + && ex.getMessage().contains("Tool call result validation failed")) + .verify(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testListToolsWithEmptyCursor() { + McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); + McpSchema.Tool subtractTool = McpSchema.Tool.builder() + .name("subtract") + .description("calculate subtract") + .build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool, subtractTool), ""); + + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + Mono mono = client.listTools(); + McpSchema.ListToolsResult toolsResult = mono.block(); + assertThat(toolsResult).isNotNull(); + + Set names = toolsResult.tools().stream().map(McpSchema.Tool::name).collect(Collectors.toSet()); + assertThat(names).containsExactlyInAnyOrder("subtract", "add"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index bf4738496..a94b9b6a7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -8,9 +8,9 @@ import java.util.List; import io.modelcontextprotocol.MockMcpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -22,7 +22,7 @@ */ class McpClientProtocolVersionTests { - private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30); + private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(300); private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); @@ -37,20 +37,21 @@ void shouldUseLatestVersionByDefault() { try { Mono initializeResultMono = client.initialize(); + String protocolVersion = transport.protocolVersions().get(transport.protocolVersions().size() - 1); + StepVerifier.create(initializeResultMono).then(() -> { McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params(); - assertThat(initRequest.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(initRequest.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, null, + new McpSchema.InitializeResult(protocolVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); }).verifyComplete(); - } finally { // Ensure cleanup happens even if test fails @@ -79,7 +80,7 @@ void shouldNegotiateSpecificVersion() { assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(oldVersion, null, + new McpSchema.InitializeResult(oldVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { @@ -108,10 +109,10 @@ void shouldFailForUnsupportedVersion() { assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(unsupportedVersion, null, + new McpSchema.InitializeResult(unsupportedVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); - }).expectError(McpError.class).verify(); + }).expectError(RuntimeException.class).verify(); } finally { StepVerifier.create(client.closeGracefully()).verifyComplete(); @@ -141,7 +142,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { assertThat(initRequest.protocolVersion()).isEqualTo(latestVersion); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(latestVersion, null, + new McpSchema.InitializeResult(latestVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java new file mode 100644 index 000000000..547ccc52f --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java @@ -0,0 +1,21 @@ +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.ServerParameters; + +public final class ServerParameterUtils { + + private ServerParameterUtils() { + } + + public static ServerParameters createServerParameters() { + if (System.getProperty("os.name").toLowerCase().contains("win")) { + return ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + return ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java new file mode 100644 index 000000000..aa8aaa397 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; + +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + +/** + * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. + * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download +class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + @Override + protected McpClientTransport createMcpTransport() { + return new StdioClientTransport(createServerParameters(), JSON_MAPPER); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(20); + } + + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 4b5f4f9c0..b1e567989 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -17,31 +17,30 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; /** * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); + ServerParameters stdioParams = createServerParameters(); + return new StdioClientTransport(stdioParams, JSON_MAPPER); } @Test @@ -71,4 +70,9 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(10); } + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java similarity index 68% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 1b1c72012..a24805a30 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,23 +7,26 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; + +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -31,15 +34,19 @@ import reactor.test.StepVerifier; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.util.UriComponentsBuilder; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.fasterxml.jackson.databind.ObjectMapper; - /** * Tests for the {@link HttpClientSseClientTransport} class. * @@ -51,14 +58,16 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); private TestHttpClientSseClientTransport transport; + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); + // Test class to access protected methods static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { @@ -67,7 +76,9 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); public TestHttpClientSseClientTransport(final String baseUri) { - super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); + super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), + HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, + McpAsyncHttpClientRequestCustomizer.NOOP); } public int getInboundMessageCount() { @@ -86,15 +97,21 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; + + } + + @AfterAll + static void stopContainer() { + container.stop(); } @BeforeEach void setUp() { - startContainer(); transport = new TestHttpClientSseClientTransport(host); transport.connect(Function.identity()).block(); } @@ -104,11 +121,16 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); } - void cleanup() { - container.stop(); + @Test + void testErrorOnBogusMessage() { + // bogus message + JSONRPCRequest bogusMessage = new JSONRPCRequest(null, null, "test-id", Map.of("key", "value")); + + StepVerifier.create(transport.sendMessage(bogusMessage)) + .verifyErrorMessage( + "Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"); } @Test @@ -372,24 +394,87 @@ void testChainedCustomizations() { } @Test - @SuppressWarnings("unchecked") - void testResolvingClientEndpoint() { - HttpClient httpClient = Mockito.mock(HttpClient.class); - HttpResponse httpResponse = Mockito.mock(HttpResponse.class); - CompletableFuture> future = new CompletableFuture<>(); - future.complete(httpResponse); - when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); + void testRequestCustomizer() { + var mockCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); - HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), - "http://example.com", "http://example.com/sse", new ObjectMapper()); + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .httpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); + clearInvocations(mockCustomizer); - transport.connect(Function.identity()); + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); - ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); - assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); + // Subscribe to messages and verify + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); - transport.closeGracefully().block(); + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testAsyncRequestCustomizer() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .asyncHttpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java new file mode 100644 index 000000000..81e642681 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Handles emplty application/json response with 200 OK status code. + * + * @author codezkk + */ +public class HttpClientStreamableHttpTransportEmptyJsonResponseTest { + + static int PORT = TomcatTestUtil.findAvailablePort(); + + static String host = "http://localhost:" + PORT; + + static HttpServer server; + + @BeforeAll + static void startContainer() throws IOException { + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Empty, 200 OK response for the /mcp endpoint + server.createContext("/mcp", exchange -> { + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, 0); + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + } + + @AfterAll + static void stopContainer() { + server.stop(1); + } + + /** + * Regardless of the response (even if the response is null and the content-type is + * present), notify should handle it correctly. + */ + @Test + @Timeout(3) + void testNotificationInitialized() throws URISyntaxException { + + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + any()); + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..b82d6eb2c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,345 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically + * tests the distinction between session-related errors and general transport errors for + * 404 and 400 status codes. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class HttpClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", httpExchange -> { + if ("DELETE".equals(httpExchange.getRequestMethod())) { + httpExchange.sendResponseHeaders(200, 0); + } + else if ("POST".equals(httpExchange.getRequestMethod())) { + // Capture session ID from request if present + String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Set response headers + httpExchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Add session ID to response if configured + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + httpExchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + + // Send response based on configured status + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + httpExchange.sendResponseHeaders(200, response.length()); + httpExchange.getResponseBody().write(response.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + } + httpExchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = HttpClientStreamableHttpTransport.builder(HOST).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + */ + @Test + void test404WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 404 for next request + serverResponseStatus.set(404); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called with SessionNotFoundException + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * This handles the case mentioned in the code comment about some implementations + * returning 400 for unknown session IDs. + */ + @Test + void test400WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test session recovery after SessionNotFoundException Verifies that a new session + * can be established after the old one is invalidated + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(lastReceivedSessionId.get()).isNull(); + + // The session should now be established + // Simulate session loss - return 404 + serverResponseStatus.set(404); + + // This should fail with SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors + */ + @Test + void testReconnectErrorHandling() { + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = HttpClientStreamableHttpTransport.builder(HOST) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Now simulate server returning 404 on reconnect + serverResponseStatus.set(404); + + // This should trigger reconnect which will fail + // The error should be handled internally and passed to exception handler + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestRequestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..2ade30e17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -0,0 +1,163 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link HttpClientStreamableHttpTransport} class. + * + * @author Daniel Garnier-Moiroux + */ +class HttpClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + private McpTransportContext context = McpTransportContext + .create(Map.of("test-transport-context-key", "some-value")); + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + void withTransport(HttpClientStreamableHttpTransport transport, Consumer c) { + try { + c.accept(transport); + } + finally { + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + } + + @Test + void testRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); + }); + } + + @Test + void testAsyncRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .asyncHttpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); + }); + } + + @Test + void testCloseUninitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..a04787aa3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingMcpAsyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpAsyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + StepVerifier + .create(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context)) + .expectNext(TEST_BUILDER) + .verifyComplete(); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "one")), + (builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "two")))); + + var headers = Mono + .from(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", + McpTransportContext.EMPTY)) + .map(HttpRequest.Builder::build) + .map(HttpRequest::headers) + .flatMapIterable(h -> h.allValues("x-test")); + + StepVerifier.create(headers).expectNext("one").expectNext("two").verifyComplete(); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..6c51a3d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DelegatingMcpSyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpSyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = Mockito.mock(McpSyncHttpClientRequestCustomizer.class); + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var testHeaderName = "x-test"; + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> builder.header(testHeaderName, "one"), + (builder, method, uri, body, ctx) -> builder.header(testHeaderName, "two"))); + + customizer.customize(TEST_BUILDER, "GET", TEST_URI, null, McpTransportContext.EMPTY); + var request = TEST_BUILDER.build(); + + assertThat(request.headers().allValues(testHeaderName)).containsExactly("one", "two"); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..8b2dea462 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication, demonstrating how contextual information can be passed from client to + * server through HTTP headers and accessed within server-side handlers. + * + *

Test Scenarios

+ *

+ * The tests cover multiple transport configurations with async servers: + *

    + *
  • Stateless server with async streamable HTTP clients
  • + *
  • Streamable server with async streamable HTTP clients
  • + *
  • SSE (Server-Sent Events) server with async SSE clients
  • + *
+ * + *

Context Propagation Flow

+ *
    + *
  1. Client-side: Context data is stored in the Reactor Context and injected into HTTP + * headers via {@link McpSyncHttpClientRequestCustomizer}
  2. + *
  3. Transport: The context travels as HTTP headers (specifically "x-test" header in + * these tests)
  4. + *
  5. Server-side: A {@link McpTransportContextExtractor} extracts the header value and + * makes it available to request handlers through {@link McpTransportContext}
  6. + *
  7. Verification: The server echoes back the received context value as the tool call + * result
  8. + *
+ * + *

+ * All tests use an embedded Tomcat server running on a dynamically allocated port to + * ensure isolation and prevent port conflicts during parallel test execution. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final String HEADER_NAME = "x-test"; + + private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + return Mono.just(builder); + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpAsyncClient asyncStreamableClient = McpClient + .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopTomcat(); + } + + @Test + void asyncClinetStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..8efb6a960 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.McpTestRequestRecordingServletFilter; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class HttpClientStreamableHttpVersionNegotiationIntegrationTests { + + private Tomcat tomcat; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private final McpTestRequestRecordingServletFilter requestRecordingFilter = new McpTestRequestRecordingServletFilter(); + + private final HttpServletStreamableServerTransportProvider transport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor( + req -> McpTransportContext.create(Map.of("protocol-version", req.getHeader("MCP-protocol-version")))) + .build(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + McpSyncServer mcpServer = McpServer.sync(transport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @AfterEach + void tearDown() { + stopTomcat(); + } + + @Test + void usesLatestVersion() { + startTomcat(); + + var client = McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT).build()) + .build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + startTomcat(); + + var transport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls).filteredOn(c -> c.method().equals("POST") && !c.body().contains("\"method\":\"initialize\"")) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + private void startTomcat() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport, requestRecordingFilter); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc8f4c4be --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,243 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test both Client and Server {@link McpTransportContext} integration, in two steps. + *

+ * First, the client calls a tool and writes data stored in a thread-local to an HTTP + * header using {@link SyncSpec#transportContextProvider(Supplier)} and + * {@link McpSyncHttpClientRequestCustomizer}. + *

+ * Then the server reads the header with a {@link McpTransportContextExtractor} and + * returns the value as the result of the tool call. + * + * @author Daniel Garnier-Moiroux + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java new file mode 100644 index 000000000..090710248 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -0,0 +1,722 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @Test + @Deprecated + void testAddTool() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + @Deprecated + void testAddDuplicateTool() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + List specs = List.of( + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + + @Test + void testRemoveTool() { + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyResourcesUpdated() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier + .create(mcpAsyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..1f5387f37 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1756 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTransportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @MethodSource("clientsForTesting") + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java new file mode 100644 index 000000000..915c658e3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -0,0 +1,678 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpServerTransportProvider} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpSyncServerTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + @Test + @Deprecated + void testAddTool() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + @Deprecated + void testAddDuplicateTool() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + + @Test + void testRemoveTool() { + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testNotifyResourcesUpdated() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer + .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Resource must not be null"); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullSpecification() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeHandlers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, (exchange, roots) -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .rootsChangeHandlers(List.of((exchange, roots) -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..62332fcdb --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link McpServerFeatures.AsyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class AsyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidAsyncToolSpecification() { + + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of(new TextContent("Test result"))).isError(false).build())) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + assertThat(specification.call()).isNull(); // deprecated field should be null + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder() + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Call handler function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + McpServerFeatures.AsyncToolSpecification.Builder builder = McpServerFeatures.AsyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder.callHandler( + (exchange, request) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build()))) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("calculator") + .title("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> { + return Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); + }) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + Mono resultMono = specification.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + @SuppressWarnings("deprecation") + void deprecatedConstructorShouldWorkCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("deprecated-tool") + .title("A deprecated tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "deprecated result"; + + // Test the deprecated constructor that takes a 'call' function + McpServerFeatures.AsyncToolSpecification specification = new McpServerFeatures.AsyncToolSpecification(tool, + (exchange, + arguments) -> Mono.just(CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build())); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.call()).isNotNull(); // deprecated field should be set + assertThat(specification.callHandler()).isNotNull(); // should be automatically + // created + + // Test that the callTool function works (it should delegate to the call function) + CallToolRequest request = new CallToolRequest("deprecated-tool", Map.of("arg1", "value1")); + Mono resultMono = specification.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + + // Test that the deprecated call function also works directly + Mono callResultMono = specification.call().apply(null, request.arguments()); + + StepVerifier.create(callResultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldConvertSyncToolSpecificationCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("sync-tool") + .title("A sync tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "sync result"; + + // Create a sync tool specification + McpServerFeatures.SyncToolSpecification syncSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()) + .build(); + + // Convert to async using fromSync + McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification + .fromSync(syncSpec); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool()).isEqualTo(tool); + assertThat(asyncSpec.callHandler()).isNotNull(); + assertThat(asyncSpec.call()).isNull(); // should be null since sync spec doesn't + // have deprecated call + + // Test that the converted async specification works correctly + CallToolRequest request = new CallToolRequest("sync-tool", Map.of("param", "value")); + Mono resultMono = asyncSpec.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + @SuppressWarnings("deprecation") + void fromSyncShouldConvertSyncToolSpecificationWithDeprecatedCallCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("sync-deprecated-tool") + .title("A sync tool with deprecated call") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "sync deprecated result"; + McpAsyncServerExchange nullExchange = null; // Mock or create a suitable exchange + // if needed + + // Create a sync tool specification using the deprecated constructor + McpServerFeatures.SyncToolSpecification syncSpec = new McpServerFeatures.SyncToolSpecification(tool, + (exchange, arguments) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()); + + // Convert to async using fromSync + McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification + .fromSync(syncSpec); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool()).isEqualTo(tool); + assertThat(asyncSpec.callHandler()).isNotNull(); + assertThat(asyncSpec.call()).isNotNull(); // should be set since sync spec has + // deprecated call + + // Test that the converted async specification works correctly via callTool + CallToolRequest request = new CallToolRequest("sync-deprecated-tool", Map.of("param", "value")); + Mono resultMono = asyncSpec.callHandler().apply(nullExchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + + // Test that the deprecated call function also works + Mono callResultMono = asyncSpec.call().apply(nullExchange, request.arguments()); + + StepVerifier.create(callResultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldReturnNullWhenSyncSpecIsNull() { + assertThat(McpServerFeatures.AsyncToolSpecification.fromSync(null)).isNull(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java new file mode 100644 index 000000000..d2b9d14d0 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; + +@Timeout(15) +class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java new file mode 100644 index 000000000..491c2d4ed --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -0,0 +1,653 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ProtocolVersions; +import net.javacrumbs.jsonunit.core.Option; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.client.RestClient; + +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +@Timeout(15) +class HttpServletStatelessIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletStatelessServerTransport mcpStatelessServerTransport; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + private Tomcat tomcat; + + @BeforeEach + public void before() { + this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpStatelessServerTransport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + } + + @AfterEach + public void after() { + if (mcpStatelessServerTransport != null) { + mcpStatelessServerTransport.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); + McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( + Tool.builder().name("tool1").title("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build(), + (transportContext, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (transportContext, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpStatelessServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (transportContext, getPromptRequest) -> null)) + .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.close(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + dynamicTool, (transportContext, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.close(); + } + } + + @Test + void testThrownMcpErrorAndJsonRpcError() throws Exception { + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool testTool = Tool.builder().name("test").description("test").build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + testTool, (transportContext, request) -> { + throw new RuntimeException("testing"); + }); + + mcpServer.addTool(toolSpec); + + McpSchema.CallToolRequest callToolRequest = new McpSchema.CallToolRequest("test", Map.of()); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_TOOLS_CALL, "test", callToolRequest); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); + MockHttpServletResponse response = new MockHttpServletResponse(); + + byte[] content = JSON_MAPPER.writeValueAsBytes(jsonrpcRequest); + request.setContent(content); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM); + request.addHeader("Content-Type", APPLICATION_JSON); + request.addHeader("Cache-Control", "no-cache"); + request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + + mcpStatelessServerTransport.service(request, response); + + McpSchema.JSONRPCResponse jsonrpcResponse = JSON_MAPPER.readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse.class); + + assertThat(jsonrpcResponse).isNotNull(); + assertThat(jsonrpcResponse.error()).isNotNull(); + assertThat(jsonrpcResponse.error().code()).isEqualTo(ErrorCodes.INTERNAL_ERROR); + assertThat(jsonrpcResponse.error().message()).isEqualTo("testing"); + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java new file mode 100644 index 000000000..96f1524b7 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java new file mode 100644 index 000000000..81423e0c5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; + +@Timeout(15) +class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletStreamableServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .mcpEndpoint(MESSAGE_ENDPOINT) + .keepAliveInterval(Duration.ofSeconds(1)) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java new file mode 100644 index 000000000..87c0712dc --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java new file mode 100644 index 000000000..640d34c9c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -0,0 +1,698 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.common.McpTransportContext; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link McpAsyncServerExchange}. + * + * @author Christian Tzolov + */ +class McpAsyncServerExchangeTests { + + @Mock + private McpServerSession mockSession; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + private McpAsyncServerExchange exchange; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + clientCapabilities = McpSchema.ClientCapabilities.builder().roots(true).build(); + + clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + + exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, + McpTransportContext.EMPTY); + } + + @Test + void testListRootsWithSinglePage() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(singlePageResult)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).hasSize(2); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(0).name()).isEqualTo("Project 1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(1).name()).isEqualTo("Project 2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithMultiplePages() { + + List page1Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + List page2Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).hasSize(3); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(2).uri()).isEqualTo("file:///home/user/project3"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithEmptyResult() { + + McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(emptyResult)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + assertThat(result.roots()).isEmpty(); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListRootsWithSpecificCursor() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), + any(TypeRef.class))) + .thenReturn(Mono.just(result)); + + StepVerifier.create(exchange.listRoots("someCursor")).assertNext(listResult -> { + assertThat(listResult.roots()).hasSize(1); + assertThat(listResult.roots().get(0).uri()).isEqualTo("file:///home/user/project3"); + assertThat(listResult.nextCursor()).isEqualTo("nextCursor"); + }).verifyComplete(); + } + + @Test + void testListRootsWithError() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Network error"))); + + // When & Then + StepVerifier.create(exchange.listRoots()).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Network error"); + }); + } + + @Test + void testListRootsUnmodifiabilityAfterAccumulation() { + + List page1Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"))); + List page2Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project2", "Project 2"))); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + StepVerifier.create(exchange.listRoots()).assertNext(result -> { + // Verify the accumulated result is correct + assertThat(result.roots()).hasSize(2); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + + // Verify that clear() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().clear()).isInstanceOf(UnsupportedOperationException.class); + + // Verify that remove() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().remove(0)).isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testGetClientCapabilities() { + assertThat(exchange.getClientCapabilities()).isEqualTo(clientCapabilities); + } + + @Test + void testGetClientInfo() { + assertThat(exchange.getClientInfo()).isEqualTo(clientInfo); + } + + // --------------------------------------- + // Logging Notification Tests + // --------------------------------------- + + @Test + void testLoggingNotificationWithNullMessage() { + StepVerifier.create(exchange.loggingNotification(null)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Logging message must not be null"); + }); + } + + @Test + void testSetMinLoggingLevelWithNullValue() { + assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("minLoggingLevel must not be null"); + } + + @Test + void testLoggingNotificationWithAllowedLevel() { + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.empty()); + + StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + } + + @Test + void testLoggingNotificationWithFilteredLevel() { + exchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(eq(McpSchema.LoggingLevel.DEBUG)); + + McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); + + StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); + + McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.WARNING) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + StepVerifier.create(exchange.loggingNotification(warningNotification)).verifyComplete(); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.WARNING)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(warningNotification)); + } + + @Test + void testLoggingNotificationWithSessionError() { + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.error(new RuntimeException("Session error"))); + + StepVerifier.create(exchange.loggingNotification(notification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session error"); + }); + } + + // --------------------------------------- + // Create Elicitation Tests + // --------------------------------------- + + @Test + void testCreateElicitationWithNullCapabilities() { + // Given - Create exchange with null capabilities + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + StepVerifier.create(exchangeWithNullCapabilities.createElicitation(elicitRequest)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + // Given - Create exchange without elicitation capabilities + McpSchema.ClientCapabilities capabilitiesWithoutElicitation = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + StepVerifier.create(exchangeWithoutElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + }); + + // Verify that sendRequest was never called due to missing elicitation + // capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithComplexRequest() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + // Create a complex elicit request with schema + java.util.Map requestedSchema = new java.util.HashMap<>(); + requestedSchema.put("type", "object"); + requestedSchema.put("properties", java.util.Map.of("name", java.util.Map.of("type", "string"), "age", + java.util.Map.of("type", "number"))); + requestedSchema.put("required", java.util.List.of("name")); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your personal information") + .requestedSchema(requestedSchema) + .build(); + + java.util.Map responseContent = new java.util.HashMap<>(); + responseContent.put("name", "John Doe"); + responseContent.put("age", 30); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(responseContent) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isNotNull(); + assertThat(result.content().get("name")).isEqualTo("John Doe"); + assertThat(result.content().get("age")).isEqualTo(30); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithDeclineAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide sensitive information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.DECLINE) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.DECLINE); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithCancelAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.CANCEL) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.CANCEL); + }).verifyComplete(); + } + + @Test + void testCreateElicitationWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session communication error"); + }); + } + + // --------------------------------------- + // Create Message Tests + // --------------------------------------- + + @Test + void testCreateMessageWithNullCapabilities() { + + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + StepVerifier.create(exchangeWithNullCapabilities.createMessage(createMessageRequest)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + McpSchema.ClientCapabilities capabilitiesWithoutSampling = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutSampling, clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + StepVerifier.create(exchangeWithoutSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + + // Verify that sendRequest was never called due to missing sampling capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithBasicRequest() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Hello! How can I help you today?")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Hello! How can I help you today?"); + assertThat(result.model()).isEqualTo("gpt-4"); + assertThat(result.stopReason()).isEqualTo(McpSchema.CreateMessageResult.StopReason.END_TURN); + }).verifyComplete(); + } + + @Test + void testCreateMessageWithImageContent() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + // Create request with image content + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.ImageContent(null, "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD...", + "image/jpeg")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("I can see an image. It appears to be a photograph.")) + .model("gpt-4-vision") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.model()).isEqualTo("gpt-4-vision"); + }).verifyComplete(); + } + + @Test + void testCreateMessageWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello")))) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session communication error"); + }); + } + + @Test + void testCreateMessageWithIncludeContext() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, + clientInfo); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("What files are available?")))) + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.ALL_SERVERS) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Based on the available context, I can see several files...")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + assertThat(((McpSchema.TextContent) result.content()).text()).contains("context"); + }).verifyComplete(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + + @Test + void testPingWithSuccessfulResponse() { + + java.util.Map expectedResponse = java.util.Map.of(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResponse)); + + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isEqualTo(expectedResponse); + assertThat(result).isInstanceOf(java.util.Map.class); + }).verifyComplete(); + + // Verify that sendRequest was called with correct parameters + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingWithMcpError() { + // Given - Mock an MCP-specific error during ping + McpError mcpError = new McpError("Server unavailable"); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.error(mcpError)); + + // When & Then + StepVerifier.create(exchange.ping()).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Server unavailable"); + }); + + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingMultipleCalls() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(Map.of())) + .thenReturn(Mono.just(Map.of())); + + // First call + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isInstanceOf(Map.class); + }).verifyComplete(); + + // Second call + StepVerifier.create(exchange.ping()).assertNext(result -> { + assertThat(result).isInstanceOf(Map.class); + }).verifyComplete(); + + // Verify that sendRequest was called twice + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java new file mode 100644 index 000000000..54fb80a78 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -0,0 +1,327 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; + +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpError; + +/** + * Tests for completion functionality with context support. + * + * @author Surbhi Bansal + */ +class McpCompletionTests { + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and con figure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + e.printStackTrace(); + } + } + } + + @Test + void testCompletionHandlerReceivesContext() { + AtomicReference receivedRequest = new AtomicReference<>(); + BiFunction completionHandler = (exchange, request) -> { + receivedRequest.set(request); + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false)); + }; + + ResourceReference resourceRef = new ResourceReference(ResourceReference.TYPE, "test://resource/{param}"); + + var resource = Resource.builder() + .uri("test://resource/{param}") + .name("Test Resource") + .description("A resource for testing") + .mimeType("text/plain") + .size(123L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification(resourceRef, completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test with context + CompleteRequest request = new CompleteRequest(resourceRef, + new CompleteRequest.CompleteArgument("param", "test"), null, + new CompleteRequest.CompleteContext(Map.of("previous", "value"))); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify handler received the context + assertThat(receivedRequest.get().context()).isNotNull(); + assertThat(receivedRequest.get().context().arguments()).containsEntry("previous", "value"); + assertThat(result.completion().values()).containsExactly("test-completion"); + } + + mcpServer.close(); + } + + @Test + void testCompletionBackwardCompatibility() { + AtomicReference contextWasNull = new AtomicReference<>(false); + BiFunction completionHandler = (exchange, request) -> { + contextWasNull.set(request.context() == null); + return new CompleteResult( + new CompleteResult.CompleteCompletion(List.of("no-context-completion"), 1, false)); + }; + + McpSchema.Prompt prompt = new Prompt("test-prompt", "this is a test prompt", + List.of(new PromptArgument("arg", "string", false))); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference(PromptReference.TYPE, "test-prompt"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test without context + CompleteRequest request = new CompleteRequest(new PromptReference(PromptReference.TYPE, "test-prompt"), + new CompleteRequest.CompleteArgument("arg", "val")); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify context was null + assertThat(contextWasNull.get()).isTrue(); + assertThat(result.completion().values()).containsExactly("no-context-completion"); + } + + mcpServer.close(); + } + + @Test + void testDependentCompletionScenario() { + BiFunction completionHandler = (exchange, request) -> { + // Simulate database/table completion scenario + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("database".equals(request.argument().name())) { + // Complete database names + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users_db", "products_db", "analytics_db"), 3, false)); + } + else if ("table".equals(request.argument().name())) { + // Complete table names based on selected database + if (request.context() != null && request.context().arguments() != null) { + String db = request.context().arguments().get("database"); + if ("users_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "sessions", "permissions"), 3, false)); + } + else if ("products_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("products", "categories", "inventory"), 3, false)); + } + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // First, complete database + CompleteRequest dbRequest = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("database", "")); + + CompleteResult dbResult = mcpClient.completeCompletion(dbRequest); + assertThat(dbResult.completion().values()).contains("users_db", "products_db"); + + // Then complete table with database context + CompleteRequest tableRequest = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "users_db"))); + + CompleteResult tableResult = mcpClient.completeCompletion(tableRequest); + assertThat(tableResult.completion().values()).containsExactly("users", "sessions", "permissions"); + + // Different database gives different tables + CompleteRequest tableRequest2 = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "products_db"))); + + CompleteResult tableResult2 = mcpClient.completeCompletion(tableRequest2); + assertThat(tableResult2.completion().values()).containsExactly("products", "categories", "inventory"); + } + + mcpServer.close(); + } + + @Test + void testCompletionErrorOnMissingContext() { + BiFunction completionHandler = (exchange, request) -> { + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("table".equals(request.argument().name())) { + // Check if database context is provided + if (request.context() == null || request.context().arguments() == null + || !request.context().arguments().containsKey("database")) { + + throw McpError.builder(ErrorCodes.INVALID_REQUEST) + .message("Please select a database first to see available tables") + .build(); + } + // Normal completion if context is provided + String db = request.context().arguments().get("database"); + if ("test_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "orders", "products"), 3, false)); + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Try to complete table without database context - should raise error + CompleteRequest requestWithoutContext = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", "")); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.completeCompletion(requestWithoutContext)) + .withMessageContaining("Please select a database first"); + + // Now complete with proper context - should work normally + CompleteRequest requestWithContext = new CompleteRequest( + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "test_db"))); + + CompleteResult resultWithContext = mcpClient.completeCompletion(requestWithContext); + assertThat(resultWithContext.completion().values()).containsExactly("users", "orders", "products"); + } + + mcpServer.close(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java similarity index 93% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index f643f1ba3..cdd2bacb7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -45,7 +45,9 @@ void shouldUseLatestVersionByDefault() { assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } @@ -93,7 +95,8 @@ void shouldSuggestLatestVersionForUnsupportedVersion() { assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java new file mode 100644 index 000000000..069d0f896 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -0,0 +1,692 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link McpSyncServerExchange}. + * + * @author Christian Tzolov + */ +class McpSyncServerExchangeTests { + + @Mock + private McpServerSession mockSession; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + private McpAsyncServerExchange asyncExchange; + + private McpSyncServerExchange exchange; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + clientCapabilities = McpSchema.ClientCapabilities.builder().roots(true).build(); + + clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + + asyncExchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + exchange = new McpSyncServerExchange(asyncExchange); + } + + @Test + void testListRootsWithSinglePage() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(singlePageResult)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).hasSize(2); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(0).name()).isEqualTo("Project 1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(1).name()).isEqualTo("Project 2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithMultiplePages() { + + List page1Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"), + new McpSchema.Root("file:///home/user/project2", "Project 2")); + List page2Roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).hasSize(3); + assertThat(result.roots().get(0).uri()).isEqualTo("file:///home/user/project1"); + assertThat(result.roots().get(1).uri()).isEqualTo("file:///home/user/project2"); + assertThat(result.roots().get(2).uri()).isEqualTo("file:///home/user/project3"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithEmptyResult() { + + McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(emptyResult)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + assertThat(result.roots()).isEmpty(); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testListRootsWithSpecificCursor() { + + List roots = Arrays.asList(new McpSchema.Root("file:///home/user/project3", "Project 3")); + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), + any(TypeRef.class))) + .thenReturn(Mono.just(result)); + + McpSchema.ListRootsResult listResult = exchange.listRoots("someCursor"); + + assertThat(listResult.roots()).hasSize(1); + assertThat(listResult.roots().get(0).uri()).isEqualTo("file:///home/user/project3"); + assertThat(listResult.nextCursor()).isEqualTo("nextCursor"); + } + + @Test + void testListRootsWithError() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Network error"))); + + // When & Then + assertThatThrownBy(() -> exchange.listRoots()).isInstanceOf(RuntimeException.class).hasMessage("Network error"); + } + + @Test + void testListRootsUnmodifiabilityAfterAccumulation() { + + List page1Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project1", "Project 1"))); + List page2Roots = new ArrayList<>( + Arrays.asList(new McpSchema.Root("file:///home/user/project2", "Project 2"))); + + McpSchema.ListRootsResult page1Result = new McpSchema.ListRootsResult(page1Roots, "cursor1"); + McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + McpSchema.ListRootsResult result = exchange.listRoots(); + + // Verify the accumulated result is correct + assertThat(result.roots()).hasSize(2); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.roots().add(new McpSchema.Root("file:///test", "Test"))) + .isInstanceOf(UnsupportedOperationException.class); + + // Verify that clear() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().clear()).isInstanceOf(UnsupportedOperationException.class); + + // Verify that remove() also throws UnsupportedOperationException + assertThatThrownBy(() -> result.roots().remove(0)).isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void testGetClientCapabilities() { + assertThat(exchange.getClientCapabilities()).isEqualTo(clientCapabilities); + } + + @Test + void testGetClientInfo() { + assertThat(exchange.getClientInfo()).isEqualTo(clientInfo); + } + + // --------------------------------------- + // Logging Notification Tests + // --------------------------------------- + + @Test + void testLoggingNotificationWithNullMessage() { + assertThatThrownBy(() -> exchange.loggingNotification(null)).isInstanceOf(McpError.class) + .hasMessage("Logging message must not be null"); + } + + @Test + void testLoggingNotificationWithAllowedLevel() { + + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.empty()); + + exchange.loggingNotification(notification); + + // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + } + + @Test + void testLoggingNotificationWithFilteredLevel() { + asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + + McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); + + exchange.loggingNotification(debugNotification); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); + + McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.WARNING) + .logger("test-logger") + .data("Debug message that should be filtered") + .build(); + + exchange.loggingNotification(warningNotification); + + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.WARNING); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(warningNotification)); + } + + @Test + void testLoggingNotificationWithSessionError() { + + McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Test error message") + .build(); + + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + .thenReturn(Mono.error(new RuntimeException("Session error"))); + + assertThatThrownBy(() -> exchange.loggingNotification(notification)).isInstanceOf(RuntimeException.class) + .hasMessage("Session error"); + } + + // --------------------------------------- + // Create Elicitation Tests + // --------------------------------------- + + @Test + void testCreateElicitationWithNullCapabilities() { + // Given - Create exchange with null capabilities + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, + clientInfo); + McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( + asyncExchangeWithNullCapabilities); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + assertThatThrownBy(() -> exchangeWithNullCapabilities.createElicitation(elicitRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + // Given - Create exchange without elicitation capabilities + McpSchema.ClientCapabilities capabilitiesWithoutElicitation = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange asyncExchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutElicitation, clientInfo); + McpSyncServerExchange exchangeWithoutElicitation = new McpSyncServerExchange(asyncExchangeWithoutElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + assertThatThrownBy(() -> exchangeWithoutElicitation.createElicitation(elicitRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + + // Verify that sendRequest was never called due to missing elicitation + // capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationWithComplexRequest() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + // Create a complex elicit request with schema + java.util.Map requestedSchema = new java.util.HashMap<>(); + requestedSchema.put("type", "object"); + requestedSchema.put("properties", java.util.Map.of("name", java.util.Map.of("type", "string"), "age", + java.util.Map.of("type", "number"))); + requestedSchema.put("required", java.util.List.of("name")); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your personal information") + .requestedSchema(requestedSchema) + .build(); + + java.util.Map responseContent = new java.util.HashMap<>(); + responseContent.put("name", "John Doe"); + responseContent.put("age", 30); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(responseContent) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isNotNull(); + assertThat(result.content().get("name")).isEqualTo("John Doe"); + assertThat(result.content().get("age")).isEqualTo(30); + } + + @Test + void testCreateElicitationWithDeclineAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide sensitive information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.DECLINE) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.DECLINE); + } + + @Test + void testCreateElicitationWithCancelAction() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your information") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.CANCEL) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.CANCEL); + } + + @Test + void testCreateElicitationWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, + capabilitiesWithElicitation, clientInfo); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + assertThatThrownBy(() -> exchangeWithElicitation.createElicitation(elicitRequest)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Session communication error"); + } + + // --------------------------------------- + // Create Message Tests + // --------------------------------------- + + @Test + void testCreateMessageWithNullCapabilities() { + + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, + clientInfo); + McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( + asyncExchangeWithNullCapabilities); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + assertThatThrownBy(() -> exchangeWithNullCapabilities.createMessage(createMessageRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + + // Verify that sendRequest was never called due to null capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + McpSchema.ClientCapabilities capabilitiesWithoutSampling = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange asyncExchangeWithoutSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutSampling, clientInfo); + McpSyncServerExchange exchangeWithoutSampling = new McpSyncServerExchange(asyncExchangeWithoutSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + assertThatThrownBy(() -> exchangeWithoutSampling.createMessage(createMessageRequest)) + .isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + + // Verify that sendRequest was never called due to missing sampling capabilities + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), + any(TypeRef.class)); + } + + @Test + void testCreateMessageWithBasicRequest() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello, world!")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Hello! How can I help you today?")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Hello! How can I help you today?"); + assertThat(result.model()).isEqualTo("gpt-4"); + assertThat(result.stopReason()).isEqualTo(McpSchema.CreateMessageResult.StopReason.END_TURN); + } + + @Test + void testCreateMessageWithImageContent() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + // Create request with image content + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.ImageContent(null, "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD...", + "image/jpeg")))) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("I can see an image. It appears to be a photograph.")) + .model("gpt-4-vision") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); + assertThat(result.model()).isEqualTo("gpt-4-vision"); + } + + @Test + void testCreateMessageWithSessionError() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays + .asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Hello")))) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + assertThatThrownBy(() -> exchangeWithSampling.createMessage(createMessageRequest)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Session communication error"); + } + + @Test + void testCreateMessageWithIncludeContext() { + + McpSchema.ClientCapabilities capabilitiesWithSampling = McpSchema.ClientCapabilities.builder() + .sampling() + .build(); + + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, + capabilitiesWithSampling, clientInfo); + McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); + + McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("What files are available?")))) + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.ALL_SERVERS) + .build(); + + McpSchema.CreateMessageResult expectedResult = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(new McpSchema.TextContent("Based on the available context, I can see several files...")) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), + any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); + + assertThat(result).isEqualTo(expectedResult); + assertThat(((McpSchema.TextContent) result.content()).text()).contains("context"); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + + @Test + void testPingWithSuccessfulResponse() { + + java.util.Map expectedResponse = java.util.Map.of(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResponse)); + + exchange.ping(); + + // Verify that sendRequest was called with correct parameters + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingWithMcpError() { + // Given - Mock an MCP-specific error during ping + McpError mcpError = new McpError("Server unavailable"); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.error(mcpError)); + + // When & Then + assertThatThrownBy(() -> exchange.ping()).isInstanceOf(McpError.class).hasMessage("Server unavailable"); + + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + + @Test + void testPingMultipleCalls() { + + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) + .thenReturn(Mono.just(Map.of())) + .thenReturn(Mono.just(Map.of())); + + // First call + exchange.ping(); + + // Second call + exchange.ping(); + + // Verify that sendRequest was called twice + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java new file mode 100644 index 000000000..61703c306 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test to verify the separation of regular resources and resource templates. Regular + * resources (without template parameters) should only appear in resources/list. Template + * resources (containing {}) should only appear in resources/templates/list. + */ +public class ResourceTemplateListingTest { + + @Test + void testTemplateResourcesFilteredFromRegularListing() { + // The change we made filters resources containing "{" from the regular listing + // This test verifies that behavior is working correctly + + // Given a string with template parameter + String templateUri = "file:///test/{userId}/profile.txt"; + assertThat(templateUri.contains("{")).isTrue(); + + // And a regular URI + String regularUri = "file:///test/regular.txt"; + assertThat(regularUri.contains("{")).isFalse(); + + // The filter should exclude template URIs + assertThat(!templateUri.contains("{")).isFalse(); + assertThat(!regularUri.contains("{")).isTrue(); + } + + @Test + void testResourceListingWithMixedResources() { + // Create resource list with both regular and template resources + List allResources = List.of( + new McpSchema.Resource("file:///test/doc1.txt", "Document 1", "text/plain", null, null), + new McpSchema.Resource("file:///test/doc2.txt", "Document 2", "text/plain", null, null), + new McpSchema.Resource("file:///test/{type}/document.txt", "Typed Document", "text/plain", null, null), + new McpSchema.Resource("file:///users/{userId}/files/{fileId}", "User File", "text/plain", null, null)); + + // Apply the filter logic from McpAsyncServer line 438 + List filteredResources = allResources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Verify only regular resources are included + assertThat(filteredResources).hasSize(2); + assertThat(filteredResources).extracting(McpSchema.Resource::uri) + .containsExactlyInAnyOrder("file:///test/doc1.txt", "file:///test/doc2.txt"); + } + + @Test + void testResourceTemplatesListedSeparately() { + // Create mixed resources + List resources = List.of( + new McpSchema.Resource("file:///test/regular.txt", "Regular Resource", "text/plain", null, null), + new McpSchema.Resource("file:///test/user/{userId}/profile.txt", "User Profile", "text/plain", null, + null)); + + // Create explicit resource template + McpSchema.ResourceTemplate explicitTemplate = new McpSchema.ResourceTemplate( + "file:///test/document/{docId}/content.txt", "Document Template", null, "text/plain", null); + + // Filter regular resources (those without template parameters) + List regularResources = resources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Extract template resources (those with template parameters) + List templateResources = resources.stream() + .filter(resource -> resource.uri().contains("{")) + .map(resource -> new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations())) + .collect(Collectors.toList()); + + // Verify regular resources list + assertThat(regularResources).hasSize(1); + assertThat(regularResources.get(0).uri()).isEqualTo("file:///test/regular.txt"); + + // Verify template resources list includes both extracted and explicit templates + assertThat(templateResources).hasSize(1); + assertThat(templateResources.get(0).uriTemplate()).isEqualTo("file:///test/user/{userId}/profile.txt"); + + // In the actual implementation, both would be combined + List allTemplates = List.of(templateResources.get(0), explicitTemplate); + assertThat(allTemplates).hasSize(2); + assertThat(allTemplates).extracting(McpSchema.ResourceTemplate::uriTemplate) + .containsExactlyInAnyOrder("file:///test/user/{userId}/profile.txt", + "file:///test/document/{docId}/content.txt"); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java new file mode 100644 index 000000000..b7d46a967 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Test suite for Resource Template Management functionality. Tests the new + * addResourceTemplate() and removeResourceTemplate() methods, as well as the Map-based + * resource template storage. + * + * @author Christian Tzolov + */ +public class ResourceTemplateManagementTests { + + private static final String TEST_TEMPLATE_URI = "test://resource/{param}"; + + private static final String TEST_TEMPLATE_NAME = "test-template"; + + private MockMcpServerTransportProvider mockTransportProvider; + + private McpAsyncServer mcpAsyncServer; + + @BeforeEach + void setUp() { + mockTransportProvider = new MockMcpServerTransportProvider(new MockMcpServerTransport()); + } + + @AfterEach + void tearDown() { + if (mcpAsyncServer != null) { + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + } + + // --------------------------------------- + // Async Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).verifyComplete(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate(TEST_TEMPLATE_URI)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource template should complete successfully (no + // error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + } + + @Test + void testReplaceExistingResourceTemplate() { + ResourceTemplate originalTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Original template") + .mimeType("text/plain") + .build(); + + ResourceTemplate updatedTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Updated template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification originalSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + originalTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification updatedSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + updatedTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(originalSpec) + .build(); + + // Adding a resource template with the same URI should replace the existing one + StepVerifier.create(mcpAsyncServer.addResourceTemplate(updatedSpec)).verifyComplete(); + } + + // --------------------------------------- + // Sync Resource Template Tests + // --------------------------------------- + + @Test + void testSyncAddResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testSyncRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Map-based Storage Tests + // --------------------------------------- + + @Test + void testResourceTemplateMapBasedStorage() { + ResourceTemplate template1 = ResourceTemplate.builder() + .uriTemplate("test://template1/{id}") + .name("template1") + .description("First template") + .mimeType("text/plain") + .build(); + + ResourceTemplate template2 = ResourceTemplate.builder() + .uriTemplate("test://template2/{id}") + .name("template2") + .description("Second template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification spec1 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template1, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification spec2 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template2, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(spec1, spec2) + .build(); + + // Verify both templates are stored (this would be tested through integration + // tests + // or by accessing internal state, but for unit tests we verify no exceptions) + assertThat(mcpAsyncServer).isNotNull(); + } + + @Test + void testResourceTemplateBuilderWithMap() { + // Test that the new Map-based builder methods work correctly + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + // Test varargs builder method + assertThatCode(() -> { + McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build() + .closeGracefully() + .block(Duration.ofSeconds(10)); + }).doesNotThrowAnyException(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 81d904292..8906adfe0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java similarity index 85% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 154cf3a61..7b77f9241 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java similarity index 69% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 0381a43bd..b2dfbea25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. * @@ -16,9 +18,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index a71c38493..c97c75d38 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * @@ -16,9 +18,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..9bcd2bc84 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; + +/** + * Tests for {@link McpServerFeatures.SyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class SyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidSyncToolSpecification() { + + Tool tool = Tool.builder().name("test-tool").title("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent("Test result"))) + .isError(false) + .build()) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + assertThat(specification.call()).isNull(); // deprecated field should be null + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder() + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("CallTool function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + McpServerFeatures.SyncToolSpecification.Builder builder = McpServerFeatures.SyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build())) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = Tool.builder() + .name("calculator") + .description("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> { + // Simple test implementation + return CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build(); + }) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + CallToolResult result = specification.callHandler().apply(null, request); + + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 2cd62889a..be88097b3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -1,9 +1,8 @@ /* * Copyright 2024 - 2024 the original author or authors. */ -package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; +package io.modelcontextprotocol.server.transport; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; @@ -39,7 +38,6 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .baseUrl(CUSTOM_CONTEXT_PATH) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java new file mode 100644 index 000000000..b94552d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; + +/** + * Simple {@link Filter} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingServletFilter implements Filter { + + private final List calls = new ArrayList<>(); + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + + if (servletRequest instanceof HttpServletRequest req) { + var headers = Collections.list(req.getHeaderNames()) + .stream() + .collect(Collectors.toUnmodifiableMap(Function.identity(), + name -> String.join(",", Collections.list(req.getHeaders(name))))); + var request = new CachedBodyHttpServletRequest(req); + calls.add(new Call(req.getMethod(), headers, request.getBodyAsString())); + filterChain.doFilter(request, servletResponse); + } + else { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + public List getCalls() { + + return List.copyOf(calls); + } + + public record Call(String method, Map headers, String body) { + + } + + public static class CachedBodyHttpServletRequest extends HttpServletRequestWrapper { + + private final byte[] cachedBody; + + public CachedBodyHttpServletRequest(HttpServletRequest request) throws IOException { + super(request); + this.cachedBody = request.getInputStream().readAllBytes(); + } + + @Override + public ServletInputStream getInputStream() { + return new CachedBodyServletInputStream(cachedBody); + } + + @Override + public BufferedReader getReader() { + return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8)); + } + + public String getBodyAsString() { + return new String(cachedBody, StandardCharsets.UTF_8); + } + + } + + public static class CachedBodyServletInputStream extends ServletInputStream { + + private InputStream cachedBodyInputStream; + + public CachedBodyServletInputStream(byte[] cachedBody) { + this.cachedBodyInputStream = new ByteArrayInputStream(cachedBody); + } + + @Override + public boolean isFinished() { + try { + return cachedBodyInputStream.available() == 0; + } + catch (IOException e) { + e.printStackTrace(); + } + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + throw new UnsupportedOperationException(); + } + + @Override + public int read() throws IOException { + return cachedBodyInputStream.read(); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java similarity index 93% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5ac..6a70af33d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -14,7 +14,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -26,6 +25,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -50,8 +50,6 @@ class StdioServerTransportProviderTests { private StdioServerTransportProvider transportProvider; - private ObjectMapper objectMapper; - private McpServerSession.Factory sessionFactory; private McpServerSession mockSession; @@ -64,8 +62,6 @@ void setUp() { System.setOut(testOutPrintStream); System.setErr(testOutPrintStream); - objectMapper = new ObjectMapper(); - // Create mocks for session factory and session mockSession = mock(McpServerSession.class); sessionFactory = mock(McpServerSession.Factory.class); @@ -75,7 +71,7 @@ void setUp() { when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); - transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, System.in, testOutPrintStream); } @AfterEach @@ -105,7 +101,7 @@ void shouldHandleIncomingMessages() throws Exception { String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, System.out); // Set up a real session to capture the message AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); @@ -185,7 +181,7 @@ void shouldHandleMultipleCloseGracefullyCalls() { @Test void shouldHandleNotificationBeforeSessionFactoryIsSet() { - transportProvider = new StdioServerTransportProvider(objectMapper); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER); // Send notification before setting session factory StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) .verifyErrorSatisfies(error -> { @@ -200,7 +196,7 @@ void shouldHandleInvalidJsonMessage() throws Exception { String jsonMessage = "{invalid json}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(JSON_MAPPER, stream, testOutPrintStream); // Set up a session factory transportProvider.setSessionFactory(sessionFactory); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index f61cdc413..490e29838 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -1,18 +1,23 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; +import jakarta.servlet.Filter; import jakarta.servlet.Servlet; import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.util.descriptor.web.FilterDef; +import org.apache.tomcat.util.descriptor.web.FilterMap; /** * @author Christian Tzolov + * @author Daniel Garnier-Moiroux */ public class TomcatTestUtil { @@ -20,7 +25,8 @@ public class TomcatTestUtil { // Prevent instantiation } - public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet, + Filter... additionalFilters) { var tomcat = new Tomcat(); tomcat.setPort(port); @@ -39,6 +45,18 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se context.addChild(wrapper); context.addServletMappingDecoded("/*", "mcpServlet"); + for (var filter : additionalFilters) { + var filterDef = new FilterDef(); + filterDef.setFilter(filter); + filterDef.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + context.addFilterDef(filterDef); + + var filterMap = new FilterMap(); + filterMap.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + filterMap.addURLPattern("/*"); + context.addFilterMap(filterMap); + } + var connector = tomcat.getConnector(); connector.setAsyncTimeout(3000); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java similarity index 54% rename from mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java index ba4e851f9..a0bd568ef 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; public class ArgumentException { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java new file mode 100644 index 000000000..55f71fea4 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.McpJsonMapper; +import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.util.Collections; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CompleteCompletionSerializationTest { + + @Test + void codeCompletionSerialization() throws IOException { + McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + McpSchema.CompleteResult.CompleteCompletion codeComplete = new McpSchema.CompleteResult.CompleteCompletion( + Collections.emptyList(), 0, false); + String json = jsonMapper.writeValueAsString(codeComplete); + String expected = """ + {"values":[],"total":0,"hasMore":false}"""; + assertEquals(expected, json, json); + + McpSchema.CompleteResult completeResult = new McpSchema.CompleteResult(codeComplete); + json = jsonMapper.writeValueAsString(completeResult); + expected = """ + {"completion":{"values":[],"total":0,"hasMore":false}}"""; + assertEquals(expected, json, json); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java new file mode 100644 index 000000000..fbe17d464 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for MCP-specific validation of JSONRPCRequest ID requirements. + * + * @author Christian Tzolov + */ +public class JSONRPCRequestMcpValidationTest { + + @Test + public void testValidStringId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", "string-id", null); + assertEquals("string-id", request.id()); + }); + } + + @Test + public void testValidIntegerId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123, null); + assertEquals(123, request.id()); + }); + } + + @Test + public void testValidLongId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123L, null); + assertEquals(123L, request.id()); + }); + } + + @Test + public void testNullIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", null, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST include an ID")); + assertTrue(exception.getMessage().contains("null IDs are not allowed")); + } + + @Test + public void testDoubleIdTypeThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", 123.45, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testBooleanIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", true, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testArrayIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new String[] { "array" }, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testObjectIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new Object(), null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java new file mode 100644 index 000000000..3de06f503 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Function; + +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, + * request-response correlation, and notification processing. + * + * @author Christian Tzolov + */ +class McpClientSessionTests { + + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + private static final String TEST_METHOD = "test.method"; + + private static final String TEST_NOTIFICATION = "test.notification"; + + private static final String ECHO_METHOD = "echo"; + + TypeRef responseType = new TypeRef<>() { + }; + + @Test + void testSendRequest() { + String testParam = "test parameter"; + String responseData = "test response"; + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Create a Mono that will emit the response after the request is sent + Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); + // Verify response handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); + }).consumeNextWith(response -> { + // Verify the request was sent + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; + assertThat(request.method()).isEqualTo(TEST_METHOD); + assertThat(request.params()).isEqualTo(testParam); + assertThat(response).isEqualTo(responseData); + }).verifyComplete(); + + session.close(); + } + + @Test + void testSendRequestWithError() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify error handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + // Simulate error response + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }).expectError(McpError.class).verify(); + + session.close(); + } + + @Test + void testRequestTimeout() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify timeout + StepVerifier.create(responseMono) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(TIMEOUT.plusSeconds(1)); + + session.close(); + } + + @Test + void testSendNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Map params = Map.of("key", "value"); + Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); + + // Verify notification was sent + StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; + assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); + assertThat(notification.params()).isEqualTo(params); + }).verifyComplete(); + + session.close(); + } + + @Test + void testRequestHandling() { + String echoMessage = "Hello MCP!"; + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "test-id", echoMessage); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.result()).isEqualTo(echoMessage); + assertThat(response.error()).isNull(); + + session.close(); + } + + @Test + void testNotificationHandling() { + Sinks.One receivedParams = Sinks.one(); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))), + Function.identity()); + + // Simulate incoming notification from the server + Map notificationParams = Map.of("status", "ready"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + TEST_NOTIFICATION, notificationParams); + + transport.simulateIncomingMessage(notification); + + // Verify handler was called + assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); + + session.close(); + } + + @Test + void testUnknownMethodHandling() { + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Simulate incoming request for unknown method + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + + session.close(); + } + + @Test + void testRequestHandlerThrowsMcpErrorWithJsonRpcError() { + // Setup: Create a request handler that throws McpError with custom error code and + // data + String testMethod = "test.customError"; + Map errorData = Map.of("customField", "customValue"); + McpClientSession.RequestHandler failingHandler = params -> Mono + .error(McpError.builder(123).message("Custom error message").data(errorData).build()); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain the custom error from McpError + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(123); + assertThat(response.error().message()).isEqualTo("Custom error message"); + assertThat(response.error().data()).isEqualTo(errorData); + + session.close(); + } + + @Test + void testRequestHandlerThrowsGenericException() { + // Setup: Create a request handler that throws a generic RuntimeException + String testMethod = "test.genericError"; + RuntimeException exception = new RuntimeException("Something went wrong"); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(exception); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with aggregated exception + // messages in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Something went wrong"); + // Verify data field contains aggregated exception messages + assertThat(response.error().data()).isNotNull(); + assertThat(response.error().data().toString()).contains("RuntimeException"); + assertThat(response.error().data().toString()).contains("Something went wrong"); + + session.close(); + } + + @Test + void testRequestHandlerThrowsExceptionWithCause() { + // Setup: Create a request handler that throws an exception with a cause chain + String testMethod = "test.chainedError"; + RuntimeException rootCause = new IllegalArgumentException("Root cause message"); + RuntimeException middleCause = new IllegalStateException("Middle cause message", rootCause); + RuntimeException topException = new RuntimeException("Top level message", middleCause); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(topException); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with full exception chain + // in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Top level message"); + // Verify data field contains the full exception chain + String dataString = response.error().data().toString(); + assertThat(dataString).contains("RuntimeException"); + assertThat(dataString).contains("Top level message"); + assertThat(dataString).contains("IllegalStateException"); + assertThat(dataString).contains("Middle cause message"); + assertThat(dataString).contains("IllegalArgumentException"); + assertThat(dataString).contains("Root cause message"); + + session.close(); + } + + @Test + void testGracefulShutdown() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + StepVerifier.create(session.closeGracefully()).verifyComplete(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java new file mode 100644 index 000000000..0978ffe0b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class McpErrorTest { + + @Test + void testNotFound() { + String uri = "file:///nonexistent.txt"; + McpError mcpError = McpError.RESOURCE_NOT_FOUND.apply(uri); + assertNotNull(mcpError.getJsonRpcError()); + assertEquals(-32002, mcpError.getJsonRpcError().code()); + assertEquals("Resource not found", mcpError.getJsonRpcError().message()); + assertEquals(Map.of("uri", uri), mcpError.getJsonRpcError().data()); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java new file mode 100644 index 000000000..6b0004cb9 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -0,0 +1,1765 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; + +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import net.javacrumbs.jsonunit.core.Option; + +/** + * @author Christian Tzolov + * @author Anurag Pant + */ +public class McpSchemaTests { + + // Content Types Tests + + @Test + void testTextContent() throws Exception { + McpSchema.TextContent test = new McpSchema.TextContent("XXX"); + String value = JSON_MAPPER.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"text","text":"XXX"}""")); + } + + @Test + void testTextContentDeserialization() throws Exception { + McpSchema.TextContent textContent = JSON_MAPPER.readValue(""" + {"type":"text","text":"XXX","_meta":{"metaKey":"metaValue"}}""", McpSchema.TextContent.class); + + assertThat(textContent).isNotNull(); + assertThat(textContent.type()).isEqualTo("text"); + assertThat(textContent.text()).isEqualTo("XXX"); + assertThat(textContent.meta()).containsKey("metaKey"); + } + + @Test + void testContentDeserializationWrongType() throws Exception { + + assertThatThrownBy(() -> JSON_MAPPER.readValue(""" + {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) + .isInstanceOf(InvalidTypeIdException.class) + .hasMessageContaining( + "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [audio, image, resource, resource_link, text]"); + } + + @Test + void testImageContent() throws Exception { + McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); + String value = JSON_MAPPER.writeValueAsString(test); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""")); + } + + @Test + void testImageContentDeserialization() throws Exception { + McpSchema.ImageContent imageContent = JSON_MAPPER.readValue(""" + {"type":"image","data":"base64encodeddata","mimeType":"image/png","_meta":{"metaKey":"metaValue"}}""", + McpSchema.ImageContent.class); + assertThat(imageContent).isNotNull(); + assertThat(imageContent.type()).isEqualTo("image"); + assertThat(imageContent.data()).isEqualTo("base64encodeddata"); + assertThat(imageContent.mimeType()).isEqualTo("image/png"); + assertThat(imageContent.meta()).containsKey("metaKey"); + } + + @Test + void testAudioContent() throws Exception { + McpSchema.AudioContent audioContent = new McpSchema.AudioContent(null, "base64encodeddata", "audio/wav"); + String value = JSON_MAPPER.writeValueAsString(audioContent); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav"}""")); + } + + @Test + void testAudioContentDeserialization() throws Exception { + McpSchema.AudioContent audioContent = JSON_MAPPER.readValue(""" + {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav","_meta":{"metaKey":"metaValue"}}""", + McpSchema.AudioContent.class); + assertThat(audioContent).isNotNull(); + assertThat(audioContent.type()).isEqualTo("audio"); + assertThat(audioContent.data()).isEqualTo("base64encodeddata"); + assertThat(audioContent.mimeType()).isEqualTo("audio/wav"); + assertThat(audioContent.meta()).containsKey("metaKey"); + } + + @Test + void testCreateMessageRequestWithMeta() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("User message"); + McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); + McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); + McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, + 0.7, 0.9); + + Map metadata = new HashMap<>(); + metadata.put("session", "test-session"); + + Map meta = new HashMap<>(); + meta.put("progressToken", "create-message-token-456"); + + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .containsEntry("_meta", Map.of("progressToken", "create-message-token-456")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("create-message-token-456"); + } + + @Test + void testEmbeddedResource() throws Exception { + McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", + "text/plain", "Sample resource content"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = JSON_MAPPER.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""")); + } + + @Test + void testEmbeddedResourceDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"},"_meta":{"metaKey":"metaValue"}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("text/plain"); + assertThat(((TextResourceContents) embeddedResource.resource()).text()).isEqualTo("Sample resource content"); + assertThat(embeddedResource.meta()).containsKey("metaKey"); + } + + @Test + void testEmbeddedResourceWithBlobContents() throws Exception { + McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", + "application/octet-stream", "base64encodedblob"); + + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + + String value = JSON_MAPPER.writeValueAsString(test); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""")); + } + + @Test + void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( + """ + {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob","_meta":{"metaKey":"metaValue"}}}""", + McpSchema.EmbeddedResource.class); + assertThat(embeddedResource).isNotNull(); + assertThat(embeddedResource.type()).isEqualTo("resource"); + assertThat(embeddedResource.resource()).isNotNull(); + assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); + assertThat(embeddedResource.resource().mimeType()).isEqualTo("application/octet-stream"); + assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).blob()) + .isEqualTo("base64encodedblob"); + assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).meta()).containsKey("metaKey"); + } + + @Test + void testResourceLink() throws Exception { + McpSchema.ResourceLink resourceLink = new McpSchema.ResourceLink("main.rs", "Main file", + "file:///project/src/main.rs", "Primary application entry point", "text/x-rust", null, null, + Map.of("metaKey", "metaValue")); + String value = JSON_MAPPER.writeValueAsString(resourceLink); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"type":"resource_link","name":"main.rs","title":"Main file","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testResourceLinkDeserialization() throws Exception { + McpSchema.ResourceLink resourceLink = JSON_MAPPER.readValue( + """ + {"type":"resource_link","name":"main.rs","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""", + McpSchema.ResourceLink.class); + assertThat(resourceLink).isNotNull(); + assertThat(resourceLink.type()).isEqualTo("resource_link"); + assertThat(resourceLink.name()).isEqualTo("main.rs"); + assertThat(resourceLink.uri()).isEqualTo("file:///project/src/main.rs"); + assertThat(resourceLink.description()).isEqualTo("Primary application entry point"); + assertThat(resourceLink.mimeType()).isEqualTo("text/x-rust"); + assertThat(resourceLink.meta()).containsEntry("metaKey", "metaValue"); + } + + // JSON-RPC Message Types Tests + + @Test + void testJSONRPCRequest() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, + params); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCNotification() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + "notification_method", params); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","method":"notification_method","params":{"key":"value"}}""")); + } + + @Test + void testJSONRPCResponse() throws Exception { + Map result = new HashMap<>(); + result.put("result_key", "result_value"); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); + + String value = JSON_MAPPER.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"result":{"result_key":"result_value"}}""")); + } + + @Test + void testJSONRPCResponseWithError() throws Exception { + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); + + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); + + String value = JSON_MAPPER.writeValueAsString(response); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); + } + + // Initialization Tests + + @Test + void testInitializeRequest() throws Exception { + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .roots(true) + .sampling() + .build(); + + McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2024_11_05, + capabilities, clientInfo, meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"roots":{"listChanged":true},"sampling":{}},"clientInfo":{"name":"test-client","version":"1.0.0"},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testInitializeResult() throws Exception { + McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder() + .logging() + .prompts(true) + .resources(true, true) + .tools(true) + .build(); + + McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); + + McpSchema.InitializeResult result = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, + capabilities, serverInfo, "Server initialized successfully"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{"listChanged":true},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"test-server","version":"1.0.0"},"instructions":"Server initialized successfully"}""")); + } + + // Resource Tests + + @Test + void testResource() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", + "text/plain", annotations); + + String value = JSON_MAPPER.writeValueAsString(resource); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","annotations":{"audience":["user","assistant"],"priority":0.8}}""")); + } + + @Test + void testResourceBuilder() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource resource = McpSchema.Resource.builder() + .uri("resource://test") + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations) + .meta(Map.of("metaKey", "metaValue")) + .build(); + + String value = JSON_MAPPER.writeValueAsString(resource); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","size":256,"annotations":{"audience":["user","assistant"],"priority":0.8},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testResourceBuilderUriRequired() { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations); + + assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); + } + + @Test + void testResourceBuilderNameRequired() { + McpSchema.Annotations annotations = new McpSchema.Annotations( + Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); + + McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() + .uri("resource://test") + .description("A test resource") + .mimeType("text/plain") + .size(256L) + .annotations(annotations); + + assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); + } + + @Test + void testResourceTemplate() throws Exception { + McpSchema.Annotations annotations = new McpSchema.Annotations(Arrays.asList(McpSchema.Role.USER), 0.5); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", + "Test Template", "A test resource template", "text/plain", annotations, meta); + + String value = JSON_MAPPER.writeValueAsString(template); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"uriTemplate":"resource://{param}/test","name":"Test Template","title":"Test Template","description":"A test resource template","mimeType":"text/plain","annotations":{"audience":["user"],"priority":0.5},"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListResourcesResult() throws Exception { + McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", + "First test resource", "text/plain", null); + + McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", + "Second test resource", "application/json", null); + + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), + "next-cursor", meta); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resources":[{"uri":"resource://test1","name":"Test Resource 1","description":"First test resource","mimeType":"text/plain"},{"uri":"resource://test2","name":"Test Resource 2","description":"Second test resource","mimeType":"application/json"}],"nextCursor":"next-cursor","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListResourceTemplatesResult() throws Exception { + McpSchema.ResourceTemplate template1 = new McpSchema.ResourceTemplate("resource://{param}/test1", + "Test Template 1", "Test Template 1", "First test template", "text/plain", null); + + McpSchema.ResourceTemplate template2 = new McpSchema.ResourceTemplate("resource://{param}/test2", + "Test Template 2", "Test Template 2", "Second test template", "application/json", null); + + McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( + Arrays.asList(template1, template2), "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"resourceTemplates":[{"uriTemplate":"resource://{param}/test1","name":"Test Template 1","title":"Test Template 1","description":"First test template","mimeType":"text/plain"},{"uriTemplate":"resource://{param}/test2","name":"Test Template 2","title":"Test Template 2","description":"Second test template","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testReadResourceRequest() throws Exception { + McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", + Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"resource://test","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testReadResourceRequestWithMeta() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "read-resource-token-123"); + + McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"resource://test","_meta":{"progressToken":"read-resource-token-123"}}""")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("read-resource-token-123"); + } + + @Test + void testReadResourceRequestDeserialization() throws Exception { + McpSchema.ReadResourceRequest request = JSON_MAPPER.readValue(""" + {"uri":"resource://test","_meta":{"progressToken":"test-token"}}""", + McpSchema.ReadResourceRequest.class); + + assertThat(request.uri()).isEqualTo("resource://test"); + assertThat(request.meta()).containsEntry("progressToken", "test-token"); + assertThat(request.progressToken()).isEqualTo("test-token"); + } + + @Test + void testReadResourceResult() throws Exception { + McpSchema.TextResourceContents contents1 = new McpSchema.TextResourceContents("resource://test1", "text/plain", + "Sample text content"); + + McpSchema.BlobResourceContents contents2 = new McpSchema.BlobResourceContents("resource://test2", + "application/octet-stream", "base64encodedblob"); + + McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2), + Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"contents":[{"uri":"resource://test1","mimeType":"text/plain","text":"Sample text content"},{"uri":"resource://test2","mimeType":"application/octet-stream","blob":"base64encodedblob"}],"_meta":{"metaKey":"metaValue"}}""")); + } + + // Prompt Tests + + @Test + void testPrompt() throws Exception { + McpSchema.PromptArgument arg1 = new McpSchema.PromptArgument("arg1", "First argument", "First argument", true); + + McpSchema.PromptArgument arg2 = new McpSchema.PromptArgument("arg2", "Second argument", "Second argument", + false); + + McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "A test prompt", + Arrays.asList(arg1, arg2), Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(prompt); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-prompt","title":"Test Prompt","description":"A test prompt","arguments":[{"name":"arg1","title":"First argument","description":"First argument","required":true},{"name":"arg2","title":"Second argument","description":"Second argument","required":false}],"_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testPromptMessage() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Hello, world!"); + + McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); + + String value = JSON_MAPPER.writeValueAsString(message); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"role":"user","content":{"type":"text","text":"Hello, world!"}}""")); + } + + @Test + void testListPromptsResult() throws Exception { + McpSchema.PromptArgument arg = new McpSchema.PromptArgument("arg", "Argument", "An argument", true); + + McpSchema.Prompt prompt1 = new McpSchema.Prompt("prompt1", "First prompt", "First prompt", + Collections.singletonList(arg)); + + McpSchema.Prompt prompt2 = new McpSchema.Prompt("prompt2", "Second prompt", "Second prompt", + Collections.emptyList()); + + McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), + "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"prompts":[{"name":"prompt1","title":"First prompt","description":"First prompt","arguments":[{"name":"arg","title":"Argument","description":"An argument","required":true}]},{"name":"prompt2","title":"Second prompt","description":"Second prompt","arguments":[]}],"nextCursor":"next-cursor"}""")); + } + + @Test + void testGetPromptRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("arg1", "value1"); + arguments.put("arg2", 42); + + McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); + + assertThat(JSON_MAPPER.readValue(""" + {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) + .isEqualTo(request); + } + + @Test + void testGetPromptRequestWithMeta() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("arg1", "value1"); + arguments.put("arg2", 42); + + Map meta = new HashMap<>(); + meta.put("progressToken", "token123"); + + McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments, meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42},"_meta":{"progressToken":"token123"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("token123"); + } + + @Test + void testGetPromptResult() throws Exception { + McpSchema.TextContent content1 = new McpSchema.TextContent("System message"); + McpSchema.TextContent content2 = new McpSchema.TextContent("User message"); + + McpSchema.PromptMessage message1 = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, content1); + + McpSchema.PromptMessage message2 = new McpSchema.PromptMessage(McpSchema.Role.USER, content2); + + McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", + Arrays.asList(message1, message2)); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"description":"A test prompt result","messages":[{"role":"assistant","content":{"type":"text","text":"System message"}},{"role":"user","content":{"type":"text","text":"User message"}}]}""")); + } + + // Tool Tests + + @Test + void testJsonSchema() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/$defs/Address" + } + }, + "required": ["name"], + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = JSON_MAPPER.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testJsonSchemaWithDefinitions() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/definitions/Address" + } + }, + "required": ["name"], + "definitions": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = JSON_MAPPER.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testTool() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); + } + + @Test + void testToolWithComplexSchema() throws Exception { + String complexSchemaJson = """ + { + "type": "object", + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "properties": { + "name": {"type": "string"}, + "shippingAddress": {"$ref": "#/$defs/Address"} + }, + "required": ["name", "shippingAddress"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("Handles addresses") + .inputSchema(JSON_MAPPER, complexSchemaJson) + .build(); + + // Serialize the tool to a string + String serialized = JSON_MAPPER.writeValueAsString(tool); + + // Deserialize back to a Tool object + McpSchema.Tool deserializedTool = JSON_MAPPER.readValue(serialized, McpSchema.Tool.class); + + // Serialize again and compare with first serialization + String serializedAgain = JSON_MAPPER.writeValueAsString(deserializedTool); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + + // Just verify the basic structure was preserved + assertThat(deserializedTool.inputSchema().defs()).isNotNull(); + assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); + } + + @Test + void testToolWithMeta() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + Map meta = Map.of("metaKey", "metaValue"); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("addressTool") + .description("Handles addresses") + .inputSchema(schema) + .meta(meta) + .build(); + + // Verify that meta value was preserved + assertThat(tool.meta()).isNotNull(); + assertThat(tool.meta()).containsKey("metaKey"); + } + + @Test + void testToolWithAnnotations() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool", false, false, false, false, + false); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .annotations(annotations) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"}, + "value":{"type":"number"} + }, + "required":["name"] + }, + "annotations":{ + "title":"A test tool", + "readOnlyHint":false, + "destructiveHint":false, + "idempotentHint":false, + "openWorldHint":false, + "returnDirect":false + } + } + """)); + } + + @Test + void testToolWithOutputSchema() throws Exception { + String inputSchemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "value": { + "type": "number" + } + }, + "required": ["name"] + } + """; + + String outputSchemaJson = """ + { + "type": "object", + "properties": { + "result": { + "type": "string" + }, + "status": { + "type": "string", + "enum": ["success", "error"] + } + }, + "required": ["result", "status"] + } + """; + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"}, + "value":{"type":"number"} + }, + "required":["name"] + }, + "outputSchema":{ + "type":"object", + "properties":{ + "result":{"type":"string"}, + "status":{ + "type":"string", + "enum":["success","error"] + } + }, + "required":["result","status"] + } + } + """)); + } + + @Test + void testToolWithOutputSchemaAndAnnotations() throws Exception { + String inputSchemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + } + }, + "required": ["name"] + } + """; + + String outputSchemaJson = """ + { + "type": "object", + "properties": { + "result": { + "type": "string" + } + }, + "required": ["result"] + } + """; + + McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool with output", true, false, + true, false, true); + + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .annotations(annotations) + .build(); + + String value = JSON_MAPPER.writeValueAsString(tool); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + { + "name":"test-tool", + "description":"A test tool", + "inputSchema":{ + "type":"object", + "properties":{ + "name":{"type":"string"} + }, + "required":["name"] + }, + "outputSchema":{ + "type":"object", + "properties":{ + "result":{"type":"string"} + }, + "required":["result"] + }, + "annotations":{ + "title":"A test tool with output", + "readOnlyHint":true, + "destructiveHint":false, + "idempotentHint":true, + "openWorldHint":false, + "returnDirect":true + } + }""")); + } + + @Test + void testToolDeserialization() throws Exception { + String toolJson = """ + { + "name": "test-tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + }, + "outputSchema": { + "type": "object", + "properties": { + "result": {"type": "string"} + }, + "required": ["result"] + }, + "annotations": { + "title": "Test Tool", + "readOnlyHint": true, + "destructiveHint": false, + "idempotentHint": true, + "openWorldHint": false, + "returnDirect": false + } + } + """; + + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); + + assertThat(tool).isNotNull(); + assertThat(tool.name()).isEqualTo("test-tool"); + assertThat(tool.description()).isEqualTo("A test tool"); + assertThat(tool.inputSchema()).isNotNull(); + assertThat(tool.inputSchema().type()).isEqualTo("object"); + assertThat(tool.outputSchema()).isNotNull(); + assertThat(tool.outputSchema()).containsKey("type"); + assertThat(tool.outputSchema().get("type")).isEqualTo("object"); + assertThat(tool.annotations()).isNotNull(); + assertThat(tool.annotations().title()).isEqualTo("Test Tool"); + assertThat(tool.annotations().readOnlyHint()).isTrue(); + assertThat(tool.annotations().idempotentHint()).isTrue(); + assertThat(tool.annotations().destructiveHint()).isFalse(); + assertThat(tool.annotations().returnDirect()).isFalse(); + } + + @Test + void testToolDeserializationWithoutOutputSchema() throws Exception { + String toolJson = """ + { + "name": "test-tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + } + } + """; + + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); + + assertThat(tool).isNotNull(); + assertThat(tool.name()).isEqualTo("test-tool"); + assertThat(tool.description()).isEqualTo("A test tool"); + assertThat(tool.inputSchema()).isNotNull(); + assertThat(tool.outputSchema()).isNull(); + assertThat(tool.annotations()).isNull(); + } + + @Test + void testCallToolRequest() throws Exception { + Map arguments = new HashMap<>(); + arguments.put("name", "test"); + arguments.put("value", 42); + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + + @Test + void testCallToolRequestJsonArguments() throws Exception { + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(JSON_MAPPER, "test-tool", """ + { + "name": "test", + "value": 42 + } + """); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + + @Test + void testCallToolRequestWithMeta() throws Exception { + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of("name", "test", "value", 42)) + .progressToken("tool-progress-123") + .build(); + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","arguments":{"name":"test","value":42},"_meta":{"progressToken":"tool-progress-123"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(Map.of("progressToken", "tool-progress-123")); + assertThat(request.progressToken()).isEqualTo("tool-progress-123"); + } + + @Test + void testCallToolRequestBuilderWithJsonArguments() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "json-builder-789"); + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(JSON_MAPPER, """ + { + "name": "test", + "value": 42 + } + """) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"name":"test-tool","arguments":{"name":"test","value":42},"_meta":{"progressToken":"json-builder-789"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("json-builder-789"); + } + + @Test + void testCallToolRequestBuilderNameRequired() { + Map arguments = new HashMap<>(); + arguments.put("name", "test"); + + McpSchema.CallToolRequest.Builder builder = McpSchema.CallToolRequest.builder().arguments(arguments); + + assertThatThrownBy(builder::build).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("name must not be empty"); + } + + @Test + void testCallToolResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .content(Collections.singletonList(content)) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilder() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Tool execution result") + .isError(false) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithMultipleContents() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addContent(textContent) + .addContent(imageContent) + .isError(false) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithContentList() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + List contents = Arrays.asList(textContent, imageContent); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":true}""")); + } + + @Test + void testCallToolResultBuilderWithErrorResult() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Error: Operation failed") + .isError(true) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Error: Operation failed"}],"isError":true}""")); + } + + @Test + void testCallToolResultStringConstructor() throws Exception { + // Test the existing string constructor alongside the builder + McpSchema.CallToolResult result1 = new McpSchema.CallToolResult("Simple result", false); + McpSchema.CallToolResult result2 = McpSchema.CallToolResult.builder() + .addTextContent("Simple result") + .isError(false) + .build(); + + String value1 = JSON_MAPPER.writeValueAsString(result1); + String value2 = JSON_MAPPER.writeValueAsString(result2); + + // Both should produce the same JSON + assertThat(value1).isEqualTo(value2); + assertThatJson(value1).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Simple result"}],"isError":false}""")); + } + + // Sampling Tests + + @Test + void testCreateMessageRequest() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("User message"); + + McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); + + McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); + + McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, + 0.7, 0.9); + + Map metadata = new HashMap<>(); + metadata.put("session", "test-session"); + + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"messages":[{"role":"user","content":{"type":"text","text":"User message"}}],"modelPreferences":{"hints":[{"name":"gpt-4"}],"costPriority":0.3,"speedPriority":0.7,"intelligencePriority":0.9},"systemPrompt":"You are a helpful assistant","includeContext":"thisServer","temperature":0.7,"maxTokens":1000,"stopSequences":["STOP","END"],"metadata":{"session":"test-session"}}""")); + } + + @Test + void testCreateMessageResult() throws Exception { + McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); + + McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(content) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); + } + + @Test + void testCreateMessageResultUnknownStopReason() throws Exception { + String input = """ + {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"arbitrary value"}"""; + + McpSchema.CreateMessageResult value = JSON_MAPPER.readValue(input, McpSchema.CreateMessageResult.class); + + McpSchema.TextContent expectedContent = new McpSchema.TextContent("Assistant response"); + McpSchema.CreateMessageResult expected = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(expectedContent) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.UNKNOWN) + .build(); + assertThat(value).isEqualTo(expected); + } + + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + + @Test + void testElicitRequestWithMeta() throws Exception { + Map requestedSchema = Map.of("type", "object", "required", List.of("name"), "properties", + Map.of("name", Map.of("type", "string"))); + + Map meta = new HashMap<>(); + meta.put("progressToken", "elicit-token-789"); + + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .requestedSchema(requestedSchema) + .meta(meta) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .containsEntry("_meta", Map.of("progressToken", "elicit-token-789")); + + // Test Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("elicit-token-789"); + } + + // Pagination Tests + + @Test + void testPaginatedRequestNoArgs() throws Exception { + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testPaginatedRequestWithCursor() throws Exception { + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123"); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"cursor":"cursor123"}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testPaginatedRequestWithMeta() throws Exception { + Map meta = new HashMap<>(); + meta.put("progressToken", "pagination-progress-456"); + + McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123", meta); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"cursor":"cursor123","_meta":{"progressToken":"pagination-progress-456"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("pagination-progress-456"); + } + + @Test + void testPaginatedRequestDeserialization() throws Exception { + McpSchema.PaginatedRequest request = JSON_MAPPER.readValue(""" + {"cursor":"test-cursor","_meta":{"progressToken":"test-token"}}""", McpSchema.PaginatedRequest.class); + + assertThat(request.cursor()).isEqualTo("test-cursor"); + assertThat(request.meta()).containsEntry("progressToken", "test-token"); + assertThat(request.progressToken()).isEqualTo("test-token"); + } + + // Complete Request Tests + + @Test + void testCompleteRequest() throws Exception { + McpSchema.PromptReference promptRef = new McpSchema.PromptReference("test-prompt"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument("arg1", + "partial-value"); + + McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(promptRef, argument); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"ref":{"type":"ref/prompt","name":"test-prompt"},"argument":{"name":"arg1","value":"partial-value"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isNull(); + assertThat(request.progressToken()).isNull(); + } + + @Test + void testCompleteRequestWithMeta() throws Exception { + McpSchema.ResourceReference resourceRef = new McpSchema.ResourceReference("file:///test.txt"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument("path", + "/partial/path"); + + Map meta = new HashMap<>(); + meta.put("progressToken", "complete-progress-789"); + + McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(resourceRef, argument, meta, null); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"ref":{"type":"ref/resource","uri":"file:///test.txt"},"argument":{"name":"path","value":"/partial/path"},"_meta":{"progressToken":"complete-progress-789"}}""")); + + // Test that it implements Request interface methods + assertThat(request.meta()).isEqualTo(meta); + assertThat(request.progressToken()).isEqualTo("complete-progress-789"); + } + + // Roots Tests + + @Test + void testRoot() throws Exception { + McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root", Map.of("metaKey", "metaValue")); + + String value = JSON_MAPPER.writeValueAsString(root); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"uri":"file:///path/to/root","name":"Test Root","_meta":{"metaKey":"metaValue"}}""")); + } + + @Test + void testListRootsResult() throws Exception { + McpSchema.Root root1 = new McpSchema.Root("file:///path/to/root1", "First Root"); + + McpSchema.Root root2 = new McpSchema.Root("file:///path/to/root2", "Second Root"); + + McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2), "next-cursor"); + + String value = JSON_MAPPER.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"roots":[{"uri":"file:///path/to/root1","name":"First Root"},{"uri":"file:///path/to/root2","name":"Second Root"}],"nextCursor":"next-cursor"}""")); + + } + + // Elicitation Capability Tests (Issue #724) + + @Test + void testElicitationCapabilityWithFormField() throws Exception { + // Test that elicitation with "form" field can be deserialized (2025-11-25 spec) + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityWithFormAndUrlFields() throws Exception { + // Test that elicitation with both "form" and "url" fields can be deserialized + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{},"url":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBackwardCompatibilityEmptyObject() throws Exception { + // Test backward compatibility: empty elicitation {} should still work + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderBackwardCompatibility() throws Exception { + // Test that the existing builder API still works and produces valid JSON + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder().elicitation().build(); + + assertThat(capabilities.elicitation()).isNotNull(); + + // Serialize and verify it produces valid JSON (should be {} for backward compat) + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"elicitation\""); + } + + @Test + void testElicitationCapabilitySerializationRoundTrip() throws Exception { + // Test that serialization and deserialization round-trip works + McpSchema.ClientCapabilities original = McpSchema.ClientCapabilities.builder().elicitation().build(); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ClientCapabilities deserialized = JSON_MAPPER.readValue(json, McpSchema.ClientCapabilities.class); + + assertThat(deserialized.elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderWithFormAndUrl() throws Exception { + // Test the new builder method that explicitly sets form and url support + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNotNull(); + + // Verify serialization produces the expected JSON + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThatJson(json).when(Option.IGNORING_ARRAY_ORDER).isObject().containsKey("elicitation"); + assertThat(json).contains("\"form\""); + assertThat(json).contains("\"url\""); + } + + @Test + void testElicitationCapabilityBuilderFormOnly() throws Exception { + // Test builder with form only + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, false) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNull(); + + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"form\""); + assertThat(json).doesNotContain("\"url\""); + } + + // Progress Notification Tests + + @Test + void testProgressNotificationWithMessage() throws Exception { + McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-123", 0.5, 1.0, + "Processing file 1 of 2", Map.of("key", "value")); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"progressToken":"progress-token-123","progress":0.5,"total":1.0,"message":"Processing file 1 of 2","_meta":{"key":"value"}}""")); + } + + @Test + void testProgressNotificationDeserialization() throws Exception { + McpSchema.ProgressNotification notification = JSON_MAPPER.readValue( + """ + {"progressToken":"token-456","progress":0.75,"total":1.0,"message":"Almost done","_meta":{"key":"value"}}""", + McpSchema.ProgressNotification.class); + + assertThat(notification.progressToken()).isEqualTo("token-456"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Almost done"); + assertThat(notification.meta()).containsEntry("key", "value"); + } + + @Test + void testProgressNotificationWithoutMessage() throws Exception { + McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-789", 0.25, + null, null); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"progressToken":"progress-token-789","progress":0.25}""")); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java new file mode 100644 index 000000000..1d7be0b51 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java @@ -0,0 +1,100 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Test class to verify the equals method implementation for PromptReference. + */ +class PromptReferenceEqualsTest { + + @Test + void testEqualsWithSameIdentifierAndType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertTrue(ref1.equals(ref2), "PromptReferences with same identifier and type should be equal"); + assertEquals(ref1.hashCode(), ref2.hashCode(), "Equal objects should have same hash code"); + } + + @Test + void testEqualsWithDifferentIdentifier() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-1", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-2", + "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different identifiers should not be equal"); + } + + @Test + void testEqualsWithDifferentType() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/other", "test-prompt", "Test Title"); + + assertFalse(ref1.equals(ref2), "PromptReferences with different types should not be equal"); + } + + @Test + void testEqualsWithNull() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertFalse(ref1.equals(null), "PromptReference should not be equal to null"); + } + + @Test + void testEqualsWithDifferentClass() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + String other = "not a PromptReference"; + + assertFalse(ref1.equals(other), "PromptReference should not be equal to different class"); + } + + @Test + void testEqualsWithSameInstance() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + + assertTrue(ref1.equals(ref1), "PromptReference should be equal to itself"); + } + + @Test + void testEqualsIgnoresTitle() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 1"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 2"); + McpSchema.PromptReference ref3 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", null); + + assertTrue(ref1.equals(ref2), "PromptReferences should be equal regardless of title"); + assertTrue(ref1.equals(ref3), "PromptReferences should be equal even when one has null title"); + assertTrue(ref2.equals(ref3), "PromptReferences should be equal even when one has null title"); + } + + @Test + void testHashCodeConsistency() { + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); + + assertEquals(ref1.hashCode(), ref2.hashCode(), "Objects that are equal should have the same hash code"); + + // Call hashCode multiple times to ensure consistency + int hashCode1 = ref1.hashCode(); + int hashCode2 = ref1.hashCode(); + assertEquals(hashCode1, hashCode2, "Hash code should be consistent across multiple calls"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java new file mode 100644 index 000000000..ef7cd2737 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.spec.json.gson; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.ToNumberPolicy; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * Test-only Gson-based implementation of McpJsonMapper. This lives under src/test/java so + * it doesn't affect production code or dependencies. + */ +public final class GsonMcpJsonMapper implements McpJsonMapper { + + private final Gson gson; + + public GsonMcpJsonMapper() { + this(new GsonBuilder().serializeNulls() + // Ensure numeric values in untyped (Object) fields preserve integral numbers + // as Long + .setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .setNumberToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .create()); + } + + public GsonMcpJsonMapper(Gson gson) { + if (gson == null) { + throw new IllegalArgumentException("Gson must not be null"); + } + this.gson = gson; + } + + public Gson getGson() { + return gson; + } + + @Override + public T readValue(String content, Class type) throws IOException { + try { + return gson.fromJson(content, type); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + try { + return gson.fromJson(content, type.getType()); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T convertValue(Object fromValue, Class type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type.getType()); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + try { + return gson.toJson(value); + } + catch (Exception e) { + throw new IOException("Failed to serialize to JSON", e); + } + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return writeValueAsString(value).getBytes(StandardCharsets.UTF_8); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java new file mode 100644 index 000000000..498194d17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.spec.json.gson; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +class GsonMcpJsonMapperTests { + + record Person(String name, int age) { + } + + @Test + void roundTripSimplePojo() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + var input = new Person("Alice", 30); + String json = mapper.writeValueAsString(input); + assertNotNull(json); + assertTrue(json.contains("\"Alice\"")); + assertTrue(json.contains("\"age\"")); + + var decoded = mapper.readValue(json, Person.class); + assertEquals(input, decoded); + + byte[] bytes = mapper.writeValueAsBytes(input); + assertNotNull(bytes); + var decodedFromBytes = mapper.readValue(bytes, Person.class); + assertEquals(input, decodedFromBytes); + } + + @Test + void readWriteParameterizedTypeWithTypeRef() throws IOException { + var mapper = new GsonMcpJsonMapper(); + String json = "[\"a\", \"b\", \"c\"]"; + + List list = mapper.readValue(json, new TypeRef>() { + }); + assertEquals(List.of("a", "b", "c"), list); + + String encoded = mapper.writeValueAsString(list); + assertTrue(encoded.startsWith("[")); + assertTrue(encoded.contains("\"a\"")); + } + + @Test + void convertValueMapToRecordAndParameterized() { + var mapper = new GsonMcpJsonMapper(); + Map src = Map.of("name", "Bob", "age", 42); + + // Convert to simple record + Person person = mapper.convertValue(src, Person.class); + assertEquals(new Person("Bob", 42), person); + + // Convert to parameterized Map + Map toMap = mapper.convertValue(person, new TypeRef>() { + }); + assertEquals("Bob", toMap.get("name")); + assertEquals(42.0, ((Number) toMap.get("age")).doubleValue(), 0.0); // Gson may + // emit double + // for + // primitives + } + + @Test + void deserializeJsonRpcMessageRequestUsingCustomMapper() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + String json = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "ping", + "params": { "x": 1, "y": "z" } + } + """; + + var msg = McpSchema.deserializeJsonRpcMessage(mapper, json); + assertTrue(msg instanceof McpSchema.JSONRPCRequest); + + var req = (McpSchema.JSONRPCRequest) msg; + assertEquals("2.0", req.jsonrpc()); + assertEquals("ping", req.method()); + assertNotNull(req.id()); + assertEquals("1", req.id().toString()); + + assertNotNull(req.params()); + assertInstanceOf(Map.class, req.params()); + @SuppressWarnings("unchecked") + var params = (Map) req.params(); + assertEquals(1.0, ((Number) params.get("x")).doubleValue(), 0.0); + assertEquals("z", params.get("y")); + } + + @Test + void integrateWithMcpSchemaStaticMapperForStringParsing() { + var gsonMapper = new GsonMcpJsonMapper(); + + // Tool builder parsing of input/output schema strings + var tool = McpSchema.Tool.builder().name("echo").description("Echo tool").inputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "x": { "type": "integer" } }, + "required": ["x"] + } + """).outputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "y": { "type": "string" } } + } + """).build(); + + assertNotNull(tool.inputSchema()); + assertNotNull(tool.outputSchema()); + assertTrue(tool.outputSchema().containsKey("properties")); + + // CallToolRequest builder parsing of JSON arguments string + var call = McpSchema.CallToolRequest.builder().name("echo").arguments(gsonMapper, "{\"x\": 123}").build(); + + assertEquals("echo", call.name()); + assertNotNull(call.arguments()); + assertTrue(call.arguments().get("x") instanceof Number); + assertEquals(123.0, ((Number) call.arguments().get("x")).doubleValue(), 0.0); + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java index 08555fef5..0038d4e1b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java @@ -8,7 +8,9 @@ import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; class AssertTests { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java new file mode 100644 index 000000000..d5ef8a91c --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -0,0 +1,303 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.scheduler.VirtualTimeScheduler; + +/** + * Unit tests for {@link KeepAliveScheduler}. + * + * @author Christian Tzolov + */ +class KeepAliveSchedulerTests { + + private MockMcpSession mockSession1; + + private MockMcpSession mockSession2; + + private Supplier> mockSessionsSupplier; + + private VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.create(); + mockSession1 = new MockMcpSession(); + mockSession2 = new MockMcpSession(); + mockSessionsSupplier = () -> Flux.just(mockSession1); + } + + @AfterEach + void tearDown() { + if (virtualTimeScheduler != null) { + virtualTimeScheduler.dispose(); + } + } + + @Test + void testBuilderWithNullSessionsSupplier() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("McpSessions supplier must not be null"); + } + + @Test + void testBuilderWithNullScheduler() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).scheduler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Scheduler must not be null"); + } + + @Test + void testBuilderWithNullInitialDelay() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).initialDelay(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Initial delay must not be null"); + } + + @Test + void testBuilderWithNullInterval() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).interval(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Interval must not be null"); + } + + @Test + void testBuilderDefaults() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier).build(); + + assertThat(scheduler).isNotNull(); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testStartWithMultipleSessions() { + mockSessionsSupplier = () -> Flux.just(mockSession1, mockSession2); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + assertThat(scheduler.isRunning()).isFalse(); + + // Start the scheduler + Disposable disposable = scheduler.start(); + + assertThat(scheduler.isRunning()).isTrue(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + + // Advance time to trigger the first ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify both sessions received ping + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(mockSession2.getPingCount()).isEqualTo(1); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Second ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Third ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Fourth ping + + // Verify second ping was sent + assertThat(mockSession1.getPingCount()).isEqualTo(4); + assertThat(mockSession2.getPingCount()).isEqualTo(4); + + // Clean up + scheduler.stop(); + + assertThat(scheduler.isRunning()).isFalse(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isTrue(); + } + + @Test + void testStartWithEmptySessionsList() { + mockSessionsSupplier = () -> Flux.empty(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger ping attempts + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify no sessions were called (since list was empty) + assertThat(mockSession1.getPingCount()).isEqualTo(0); + assertThat(mockSession2.getPingCount()).isEqualTo(0); + + // Clean up + scheduler.stop(); + } + + @Test + void testStartWhenAlreadyRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + + // Try to start again - should throw exception + assertThatThrownBy(scheduler::start).isInstanceOf(IllegalStateException.class) + .hasMessage("KeepAlive scheduler is already running. Stop it first."); + + // Clean up + scheduler.stop(); + } + + @Test + void testStopWhenNotRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Should not throw exception when stopping a non-running scheduler + assertDoesNotThrow(scheduler::stop); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testShutdown() { + // Setup with a separate virtual time scheduler (which is disposable) + VirtualTimeScheduler separateScheduler = VirtualTimeScheduler.create(); + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(separateScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + assertThat(scheduler.isRunning()).isTrue(); + + // Shutdown should stop the scheduler and dispose the scheduler + scheduler.shutdown(); + assertThat(scheduler.isRunning()).isFalse(); + assertThat(separateScheduler.isDisposed()).isTrue(); + } + + @Test + void testPingFailureHandling() { + // Setup session that fails ping + mockSession1.setShouldFailPing(true); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger the ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify ping was attempted (error should be handled gracefully) + assertThat(mockSession1.getPingCount()).isEqualTo(1); + + // Scheduler should still be running despite the error + assertThat(scheduler.isRunning()).isTrue(); + + // Clean up + scheduler.stop(); + } + + @Test + void testDisposableReturnedFromStart() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start and get disposable + Disposable disposable = scheduler.start(); + + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + assertThat(scheduler.isRunning()).isTrue(); + + // Dispose directly through the returned disposable + disposable.dispose(); + + assertThat(disposable.isDisposed()).isTrue(); + assertThat(scheduler.isRunning()).isFalse(); + } + + /** + * Simple mock implementation of McpSession for testing purposes. + */ + private static class MockMcpSession implements McpSession { + + private final AtomicInteger pingCount = new AtomicInteger(0); + + private boolean shouldFailPing = false; + + @Override + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { + if (McpSchema.METHOD_PING.equals(method)) { + pingCount.incrementAndGet(); + if (shouldFailPing) { + return Mono.error(new RuntimeException("Connection failed")); + } + return Mono.just((T) new Object()); + } + return Mono.empty(); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + // No-op for mock + } + + public int getPingCount() { + return pingCount.get(); + } + + public void setShouldFailPing(boolean shouldFailPing) { + this.shouldFailPing = shouldFailPing; + } + + @Override + public String toString() { + return "MockMcpSession"; + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..911506e01 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java diff --git a/mcp/src/test/resources/logback.xml b/mcp-core/src/test/resources/logback.xml similarity index 100% rename from mcp/src/test/resources/logback.xml rename to mcp-core/src/test/resources/logback.xml diff --git a/mcp-json-jackson2/pom.xml b/mcp-json-jackson2/pom.xml new file mode 100644 index 000000000..de2ac58ce --- /dev/null +++ b/mcp-json-jackson2/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json-jackson2 + jar + Java MCP SDK JSON Jackson + Java MCP SDK JSON implementation based on Jackson + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + io.modelcontextprotocol.sdk + mcp-json + 0.18.0-SNAPSHOT + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.networknt + json-schema-validator + ${json-schema-validator.version} + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java new file mode 100644 index 000000000..6aa2b4ebc --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapper.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; + +/** + * Jackson-based implementation of JsonMapper. Wraps a Jackson ObjectMapper but keeps the + * SDK decoupled from Jackson at the API level. + */ +public final class JacksonMcpJsonMapper implements McpJsonMapper { + + private final ObjectMapper objectMapper; + + /** + * Constructs a new JacksonMcpJsonMapper instance with the given ObjectMapper. + * @param objectMapper the ObjectMapper to be used for JSON serialization and + * deserialization. Must not be null. + * @throws IllegalArgumentException if the provided ObjectMapper is null. + */ + public JacksonMcpJsonMapper(ObjectMapper objectMapper) { + if (objectMapper == null) { + throw new IllegalArgumentException("ObjectMapper must not be null"); + } + this.objectMapper = objectMapper; + } + + /** + * Returns the underlying Jackson {@link ObjectMapper} used for JSON serialization and + * deserialization. + * @return the ObjectMapper instance + */ + public ObjectMapper getObjectMapper() { + return objectMapper; + } + + @Override + public T readValue(String content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T convertValue(Object fromValue, Class type) { + return objectMapper.convertValue(fromValue, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.convertValue(fromValue, javaType); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + return objectMapper.writeValueAsString(value); + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return objectMapper.writeValueAsBytes(value); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java new file mode 100644 index 000000000..0e79c3e0e --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson/JacksonMcpJsonMapperSupplier.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.McpJsonMapperSupplier; + +/** + * A supplier of {@link McpJsonMapper} instances that uses the Jackson library for JSON + * serialization and deserialization. + *

+ * This implementation provides a {@link McpJsonMapper} backed by a Jackson + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + */ +public class JacksonMcpJsonMapperSupplier implements McpJsonMapperSupplier { + + /** + * Returns a new instance of {@link McpJsonMapper} that uses the Jackson library for + * JSON serialization and deserialization. + *

+ * The returned {@link McpJsonMapper} is backed by a new instance of + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + * @return a new {@link McpJsonMapper} instance + */ + @Override + public McpJsonMapper get() { + return new JacksonMcpJsonMapper(new com.fasterxml.jackson.databind.ObjectMapper()); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java new file mode 100644 index 000000000..1ff28cb80 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/DefaultJsonSchemaValidator.java @@ -0,0 +1,161 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema.jackson; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.networknt.schema.Schema; +import com.networknt.schema.SchemaRegistry; +import com.networknt.schema.Error; +import com.networknt.schema.dialect.Dialects; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Default implementation of the {@link JsonSchemaValidator} interface. This class + * provides methods to validate structured content against a JSON schema. It uses the + * NetworkNT JSON Schema Validator library for validation. + * + * @author Christian Tzolov + */ +public class DefaultJsonSchemaValidator implements JsonSchemaValidator { + + private static final Logger logger = LoggerFactory.getLogger(DefaultJsonSchemaValidator.class); + + private final ObjectMapper objectMapper; + + private final SchemaRegistry schemaFactory; + + // TODO: Implement a strategy to purge the cache (TTL, size limit, etc.) + private final ConcurrentHashMap schemaCache; + + public DefaultJsonSchemaValidator() { + this(new ObjectMapper()); + } + + public DefaultJsonSchemaValidator(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + this.schemaFactory = SchemaRegistry.withDialect(Dialects.getDraft202012()); + this.schemaCache = new ConcurrentHashMap<>(); + } + + @Override + public ValidationResponse validate(Map schema, Object structuredContent) { + + if (schema == null) { + throw new IllegalArgumentException("Schema must not be null"); + } + if (structuredContent == null) { + throw new IllegalArgumentException("Structured content must not be null"); + } + + try { + + JsonNode jsonStructuredOutput = (structuredContent instanceof String) + ? this.objectMapper.readTree((String) structuredContent) + : this.objectMapper.valueToTree(structuredContent); + + List validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); + + // Check if validation passed + if (!validationResult.isEmpty()) { + return ValidationResponse + .asInvalid("Validation failed: structuredContent does not match tool outputSchema. " + + "Validation errors: " + validationResult); + } + + return ValidationResponse.asValid(jsonStructuredOutput.toString()); + + } + catch (JsonProcessingException e) { + logger.error("Failed to validate CallToolResult: Error parsing schema: {}", e); + return ValidationResponse.asInvalid("Error parsing tool JSON Schema: " + e.getMessage()); + } + catch (Exception e) { + logger.error("Failed to validate CallToolResult: Unexpected error: {}", e); + return ValidationResponse.asInvalid("Unexpected validation error: " + e.getMessage()); + } + } + + /** + * Gets a cached Schema or creates and caches a new one. + * @param schema the schema map to convert + * @return the compiled Schema + * @throws JsonProcessingException if schema processing fails + */ + private Schema getOrCreateJsonSchema(Map schema) throws JsonProcessingException { + // Generate cache key based on schema content + String cacheKey = this.generateCacheKey(schema); + + // Try to get from cache first + Schema cachedSchema = this.schemaCache.get(cacheKey); + if (cachedSchema != null) { + return cachedSchema; + } + + // Create new schema if not in cache + Schema newSchema = this.createJsonSchema(schema); + + // Cache the schema + Schema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); + return existingSchema != null ? existingSchema : newSchema; + } + + /** + * Creates a new Schema from the given schema map. + * @param schema the schema map + * @return the compiled Schema + * @throws JsonProcessingException if schema processing fails + */ + private Schema createJsonSchema(Map schema) throws JsonProcessingException { + // Convert schema map directly to JsonNode (more efficient than string + // serialization) + JsonNode schemaNode = this.objectMapper.valueToTree(schema); + + // Handle case where ObjectMapper might return null (e.g., in mocked scenarios) + if (schemaNode == null) { + throw new JsonProcessingException("Failed to convert schema to JsonNode") { + }; + } + + return this.schemaFactory.getSchema(schemaNode); + } + + /** + * Generates a cache key for the given schema map. + * @param schema the schema map + * @return a cache key string + */ + protected String generateCacheKey(Map schema) { + if (schema.containsKey("$id")) { + // Use the (optional) "$id" field as the cache key if present + return "" + schema.get("$id"); + } + // Fall back to schema's hash code as a simple cache key + // For more sophisticated caching, could use content-based hashing + return String.valueOf(schema.hashCode()); + } + + /** + * Clears the schema cache. Useful for testing or memory management. + */ + public void clearCache() { + this.schemaCache.clear(); + } + + /** + * Returns the current size of the schema cache. + * @return the number of cached schemas + */ + public int getCacheSize() { + return this.schemaCache.size(); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..86153a538 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson/JacksonJsonSchemaValidatorSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema.jackson; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; + +/** + * A concrete implementation of {@link JsonSchemaValidatorSupplier} that provides a + * {@link JsonSchemaValidator} instance based on the Jackson library. + * + * @see JsonSchemaValidatorSupplier + * @see JsonSchemaValidator + */ +public class JacksonJsonSchemaValidatorSupplier implements JsonSchemaValidatorSupplier { + + /** + * Returns a new instance of {@link JsonSchemaValidator} that uses the Jackson library + * for JSON schema validation. + * @return A {@link JsonSchemaValidator} instance. + */ + @Override + public JsonSchemaValidator get() { + return new DefaultJsonSchemaValidator(); + } + +} diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier new file mode 100644 index 000000000..8ea66d698 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapperSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier new file mode 100644 index 000000000..0fb0b7e5a --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.schema.jackson.JacksonJsonSchemaValidatorSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java new file mode 100644 index 000000000..7642f0480 --- /dev/null +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java @@ -0,0 +1,808 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.json.schema.jackson.DefaultJsonSchemaValidator; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; + +/** + * Tests for {@link DefaultJsonSchemaValidator}. + * + * @author Christian Tzolov + */ +class DefaultJsonSchemaValidatorTests { + + private DefaultJsonSchemaValidator validator; + + private ObjectMapper objectMapper; + + @Mock + private ObjectMapper mockObjectMapper; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + validator = new DefaultJsonSchemaValidator(); + objectMapper = new ObjectMapper(); + } + + /** + * Utility method to convert JSON string to Map + */ + private Map toMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private List> toListMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + @Test + void testDefaultConstructor() { + DefaultJsonSchemaValidator defaultValidator = new DefaultJsonSchemaValidator(); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = defaultValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testConstructorWithObjectMapper() { + ObjectMapper customMapper = new ObjectMapper(); + DefaultJsonSchemaValidator customValidator = new DefaultJsonSchemaValidator(customMapper); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = customValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testValidateWithValidStringSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": 30 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertNotNull(response.jsonStructuredOutput()); + } + + @Test + void testValidateWithValidNumberSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "price": {"type": "number", "minimum": 0}, + "quantity": {"type": "integer", "minimum": 1} + }, + "required": ["price", "quantity"] + } + """; + + String contentJson = """ + { + "price": 19.99, + "quantity": 5 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["items"] + } + """; + + String contentJson = """ + { + "items": ["apple", "banana", "cherry"] + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchemaTopLevelArray() { + String schemaJson = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "city" : { + "type" : "string" + }, + "summary" : { + "type" : "string" + }, + "temperatureC" : { + "type" : "number", + "format" : "float" + } + }, + "required" : [ "city", "summary", "temperatureC" ] + }, + "additionalProperties" : false + } + """; + + String contentJson = """ + [ + { + "city": "London", + "summary": "Generally mild with frequent rainfall. Winters are cool and damp, summers are warm but rarely hot. Cloudy conditions are common throughout the year.", + "temperatureC": 11.3 + }, + { + "city": "New York", + "summary": "Four distinct seasons with hot and humid summers, cold winters with snow, and mild springs and autumns. Precipitation is fairly evenly distributed throughout the year.", + "temperatureC": 12.8 + }, + { + "city": "San Francisco", + "summary": "Mild year-round with a distinctive Mediterranean climate. Famous for summer fog, mild winters, and little temperature variation throughout the year. Very little rainfall in summer months.", + "temperatureC": 14.6 + }, + { + "city": "Tokyo", + "summary": "Humid subtropical climate with hot, wet summers and mild winters. Experiences a rainy season in early summer and occasional typhoons in late summer to early autumn.", + "temperatureC": 15.4 + } + ] + """; + + Map schema = toMap(schemaJson); + + // Validate as JSON string + ValidationResponse response = validator.validate(schema, contentJson); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + + List> structuredContent = toListMap(contentJson); + + // Validate as List> + response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidTypeSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": "thirty" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + assertTrue(response.errorMessage().contains("structuredContent does not match tool outputSchema")); + } + + @Test + void testValidateWithMissingRequiredField() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesNotAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithDefaultAdditionalProperties() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyDisallowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithEmptySchema() { + String schemaJson = """ + { + "additionalProperties": true + } + """; + + String contentJson = """ + { + "anything": "goes" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithEmptyContent() { + String schemaJson = """ + { + "type": "object", + "properties": {} + } + """; + + String contentJson = """ + {} + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St", + "city": "Anytown" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithJsonProcessingException() throws Exception { + DefaultJsonSchemaValidator validatorWithMockMapper = new DefaultJsonSchemaValidator(mockObjectMapper); + + Map schema = Map.of("type", "object"); + Map structuredContent = Map.of("key", "value"); + + // This will trigger our null check and throw JsonProcessingException + when(mockObjectMapper.valueToTree(any())).thenReturn(null); + + ValidationResponse response = validatorWithMockMapper.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Error parsing tool JSON Schema")); + assertTrue(response.errorMessage().contains("Failed to convert schema to JsonNode")); + } + + @ParameterizedTest + @MethodSource("provideValidSchemaAndContentPairs") + void testValidateWithVariousValidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertTrue(response.valid(), "Expected validation to pass for schema: " + schema + " and content: " + content); + assertNull(response.errorMessage()); + } + + @ParameterizedTest + @MethodSource("provideInvalidSchemaAndContentPairs") + void testValidateWithVariousInvalidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertFalse(response.valid(), "Expected validation to fail for schema: " + schema + " and content: " + content); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + private static Map staticToMap(String json) { + try { + ObjectMapper mapper = new ObjectMapper(); + return mapper.readValue(json, new TypeReference>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private static Stream provideValidSchemaAndContentPairs() { + return Stream.of( + // Boolean schema + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": true + } + """)), + // String with additional properties allowed + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "additionalProperties": true + } + """), staticToMap(""" + { + "name": "test", + "extra": "allowed" + } + """)), + // Array with specific items + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": [1.0, 2.5, 3.14] + } + """)), + // Enum validation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "active" + } + """))); + } + + private static Stream provideInvalidSchemaAndContentPairs() { + return Stream.of( + // Wrong boolean type + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": "true" + } + """)), + // Array with wrong item types + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": ["one", "two", "three"] + } + """)), + // Invalid enum value + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "unknown" + } + """)), + // Minimum constraint violation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "age": {"type": "integer", "minimum": 0} + } + } + """), staticToMap(""" + { + "age": -5 + } + """))); + } + + @Test + void testValidationResponseToValid() { + String jsonOutput = "{\"test\":\"value\"}"; + ValidationResponse response = ValidationResponse.asValid(jsonOutput); + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertEquals(jsonOutput, response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseToInvalid() { + String errorMessage = "Test error message"; + ValidationResponse response = ValidationResponse.asInvalid(errorMessage); + assertFalse(response.valid()); + assertEquals(errorMessage, response.errorMessage()); + assertNull(response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseRecord() { + ValidationResponse response1 = new ValidationResponse(true, null, "{\"valid\":true}"); + ValidationResponse response2 = new ValidationResponse(false, "Error", null); + + assertTrue(response1.valid()); + assertNull(response1.errorMessage()); + assertEquals("{\"valid\":true}", response1.jsonStructuredOutput()); + + assertFalse(response2.valid()); + assertEquals("Error", response2.errorMessage()); + assertNull(response2.jsonStructuredOutput()); + + // Test equality + ValidationResponse response3 = new ValidationResponse(true, null, "{\"valid\":true}"); + assertEquals(response1, response3); + assertNotEquals(response1, response2); + } + +} diff --git a/mcp-json/pom.xml b/mcp-json/pom.xml new file mode 100644 index 000000000..2cbcf3516 --- /dev/null +++ b/mcp-json/pom.xml @@ -0,0 +1,39 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 0.18.0-SNAPSHOT + + mcp-json + jar + Java MCP SDK JSON Support + Java MCP SDK JSON Support API + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + true + + + + + + + + + + diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java new file mode 100644 index 000000000..31930ab33 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonInternal.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Utility class for creating a default {@link McpJsonMapper} instance. This class + * provides a single method to create a default mapper using the {@link ServiceLoader} + * mechanism. + */ +final class McpJsonInternal { + + private static McpJsonMapper defaultJsonMapper = null; + + /** + * Returns the cached default {@link McpJsonMapper} instance. If the default mapper + * has not been created yet, it will be initialized using the + * {@link #createDefaultMapper()} method. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper getDefaultMapper() { + if (defaultJsonMapper == null) { + defaultJsonMapper = McpJsonInternal.createDefaultMapper(); + } + return defaultJsonMapper; + } + + /** + * Creates a default {@link McpJsonMapper} instance using the {@link ServiceLoader} + * mechanism. The default mapper is resolved by loading the first available + * {@link McpJsonMapperSupplier} implementation on the classpath. + * @return the default {@link McpJsonMapper} instance + * @throws IllegalStateException if no default {@link McpJsonMapper} implementation is + * found + */ + static McpJsonMapper createDefaultMapper() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(McpJsonMapperSupplier.class).stream().flatMap(p -> { + try { + McpJsonMapperSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.ofNullable(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default McpJsonMapper implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default McpJsonMapper", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java new file mode 100644 index 000000000..1e30cad16 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java @@ -0,0 +1,110 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.io.IOException; + +/** + * Abstraction for JSON serialization/deserialization to decouple the SDK from any + * specific JSON library. A default implementation backed by Jackson is provided in + * io.modelcontextprotocol.spec.json.jackson.JacksonJsonMapper. + */ +public interface McpJsonMapper { + + /** + * Deserialize JSON string into a target type. + * @param content JSON as String + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, Class type) throws IOException; + + /** + * Deserialize JSON bytes into a target type. + * @param content JSON as bytes + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, Class type) throws IOException; + + /** + * Deserialize JSON string into a parameterized target type. + * @param content JSON as String + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, TypeRef type) throws IOException; + + /** + * Deserialize JSON bytes into a parameterized target type. + * @param content JSON as bytes + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, TypeRef type) throws IOException; + + /** + * Convert a value to a given type, useful for mapping nested JSON structures. + * @param fromValue source value + * @param type target class + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, Class type); + + /** + * Convert a value to a given parameterized type. + * @param fromValue source value + * @param type target type reference + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, TypeRef type); + + /** + * Serialize an object to JSON string. + * @param value object to serialize + * @return JSON as String + * @throws IOException on serialization errors + */ + String writeValueAsString(Object value) throws IOException; + + /** + * Serialize an object to JSON bytes. + * @param value object to serialize + * @return JSON as bytes + * @throws IOException on serialization errors + */ + byte[] writeValueAsBytes(Object value) throws IOException; + + /** + * Returns the default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper getDefault() { + return McpJsonInternal.getDefaultMapper(); + } + + /** + * Creates a new default {@link McpJsonMapper}. + * @return The default {@link McpJsonMapper} + * @throws IllegalStateException If no {@link McpJsonMapper} implementation exists on + * the classpath. + */ + static McpJsonMapper createDefault() { + return McpJsonInternal.createDefaultMapper(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java new file mode 100644 index 000000000..619f96040 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java @@ -0,0 +1,14 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.function.Supplier; + +/** + * Strategy interface for resolving a {@link McpJsonMapper}. + */ +public interface McpJsonMapperSupplier extends Supplier { + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java new file mode 100644 index 000000000..ab37b43f3 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/TypeRef.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * Captures generic type information at runtime for parameterized JSON (de)serialization. + * Usage: TypeRef<List<Foo>> ref = new TypeRef<>(){}; + */ +public abstract class TypeRef { + + private final Type type; + + /** + * Constructs a new TypeRef instance, capturing the generic type information of the + * subclass. This constructor should be called from an anonymous subclass to capture + * the actual type arguments. For example:

+	 * TypeRef<List<Foo>> ref = new TypeRef<>(){};
+	 * 
+ * @throws IllegalStateException if TypeRef is not subclassed with actual type + * information + */ + protected TypeRef() { + Type superClass = getClass().getGenericSuperclass(); + if (superClass instanceof Class) { + throw new IllegalStateException("TypeRef constructed without actual type information"); + } + this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0]; + } + + /** + * Returns the captured type information. + * @return the Type representing the actual type argument captured by this TypeRef + * instance + */ + public Type getType() { + return type; + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java new file mode 100644 index 000000000..2497e7f80 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaInternal.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * Internal utility class for creating a default {@link JsonSchemaValidator} instance. + * This class uses the {@link ServiceLoader} to discover and instantiate a + * {@link JsonSchemaValidatorSupplier} implementation. + */ +final class JsonSchemaInternal { + + private static JsonSchemaValidator defaultValidator = null; + + /** + * Returns the default {@link JsonSchemaValidator} instance. If the default validator + * has not been initialized, it will be created using the {@link ServiceLoader} to + * discover and instantiate a {@link JsonSchemaValidatorSupplier} implementation. + * @return The default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation exists on the classpath or if an error occurs during instantiation. + */ + static JsonSchemaValidator getDefaultValidator() { + if (defaultValidator == null) { + defaultValidator = JsonSchemaInternal.createDefaultValidator(); + } + return defaultValidator; + } + + /** + * Creates a default {@link JsonSchemaValidator} instance by loading a + * {@link JsonSchemaValidatorSupplier} implementation using the {@link ServiceLoader}. + * @return A default {@link JsonSchemaValidator} instance. + * @throws IllegalStateException If no {@link JsonSchemaValidatorSupplier} + * implementation is found or if an error occurs during instantiation. + */ + static JsonSchemaValidator createDefaultValidator() { + AtomicReference ex = new AtomicReference<>(); + return ServiceLoader.load(JsonSchemaValidatorSupplier.class).stream().flatMap(p -> { + try { + JsonSchemaValidatorSupplier supplier = p.get(); + return Stream.ofNullable(supplier); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).flatMap(jsonMapperSupplier -> { + try { + return Stream.of(jsonMapperSupplier.get()); + } + catch (Exception e) { + addException(ex, e); + return Stream.empty(); + } + }).findFirst().orElseThrow(() -> { + if (ex.get() != null) { + return ex.get(); + } + else { + return new IllegalStateException("No default JsonSchemaValidatorSupplier implementation found"); + } + }); + } + + private static void addException(AtomicReference ref, Exception toAdd) { + ref.updateAndGet(existing -> { + if (existing == null) { + return new IllegalStateException("Failed to initialize default JsonSchemaValidatorSupplier", toAdd); + } + else { + existing.addSuppressed(toAdd); + return existing; + } + }); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java new file mode 100644 index 000000000..8e35c0237 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + + /** + * Creates the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator createDefault() { + return JsonSchemaInternal.createDefaultValidator(); + } + + /** + * Returns the default {@link JsonSchemaValidator}. + * @return The default {@link JsonSchemaValidator} + * @throws IllegalStateException If no {@link JsonSchemaValidator} implementation + * exists on the classpath. + */ + static JsonSchemaValidator getDefault() { + return JsonSchemaInternal.getDefaultValidator(); + } + +} diff --git a/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..6f69169a0 --- /dev/null +++ b/mcp-json/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.function.Supplier; + +/** + * A supplier interface that provides a {@link JsonSchemaValidator} instance. + * Implementations of this interface are expected to return a new or cached instance of + * {@link JsonSchemaValidator} when {@link #get()} is invoked. + * + * @see JsonSchemaValidator + * @see Supplier + */ +public interface JsonSchemaValidatorSupplier extends Supplier { + +} diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe95..f1737a477 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webflux jar - WebFlux implementation of the Java MCP SSE transport - + WebFlux transports + WebFlux implementation for the SSE and Streamable Http Client and Server transports https://github.com/modelcontextprotocol/java-sdk @@ -22,16 +22,22 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT test @@ -127,6 +133,13 @@ test + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index dd7c65396..0b5ce55cd 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -1,6 +1,12 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; @@ -18,18 +24,23 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -63,6 +74,8 @@ */ public class WebClientStreamableHttpTransport implements McpClientTransport { + private static final String MISSING_SESSION_ID = "[missing_session_id]"; + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); private static final String DEFAULT_ENDPOINT = "/mcp"; @@ -76,7 +89,7 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { }; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final WebClient webClient; @@ -86,20 +99,35 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference> activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); private final AtomicReference> exceptionHandler = new AtomicReference<>(); - private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, - String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { - this.objectMapper = objectMapper; + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private WebClientStreamableHttpTransport(McpJsonMapper jsonMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); + this.supportedProtocolVersions = List.copyOf(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); + } + + @Override + public List protocolVersions() { + return supportedProtocolVersions; } /** @@ -125,18 +153,30 @@ public Mono connect(Function, Mono createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() - : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { - httpHeaders.add("mcp-session-id", sessionId); - }) + : webClient.delete() + .uri(this.endpoint) + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, this.latestSupportedProtocolVersion) .retrieve() .toBodilessEntity() - .doOnError(e -> logger.warn("Got error when closing transport", e)) + .onErrorComplete(e -> { + logger.warn("Got error when closing transport", e); + return true; + }) .then(); return new DefaultMcpTransportSession(onClose); } + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + @Override public void setExceptionHandler(Consumer handler) { logger.debug("Exception handler registered"); @@ -160,9 +200,9 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); if (currentSession != null) { - return currentSession.closeGracefully(); + return Mono.from(currentSession.closeGracefully()); } return Mono.empty(); }); @@ -186,10 +226,13 @@ private Mono reconnect(McpTransportStream stream) { Disposable connection = webClient.get() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); if (stream != null) { - stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + stream.lastId().ifPresent(id -> httpHeaders.add(HttpHeaders.LAST_EVENT_ID, id)); } }) .exchangeToFlux(response -> { @@ -202,8 +245,13 @@ else if (isNotAllowed(response)) { return Flux.empty(); } else if (isNotFound(response)) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - return mcpSessionNotFoundError(sessionIdRepresentation); + if (transportSession.sessionId().isPresent()) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return this.extractError(response, MISSING_SESSION_ID); + } } else { return response.createError().doOnError(e -> { @@ -243,16 +291,19 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); - Disposable connection = webClient.post() + Disposable connection = Flux.deferContextual(ctx -> webClient.post() .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); }) .bodyValue(message) .exchangeToFlux(response -> { if (transportSession - .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + .markInitialized(response.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID))) { // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. reconnect(null).contextWrite(sink.contextView()).subscribe(); @@ -264,9 +315,11 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // 200 OK for notifications if (response.statusCode().is2xxSuccessful()) { Optional contentType = response.headers().contentType(); + long contentLength = response.headers().contentLength().orElse(-1); // Existing SDKs consume notifications with no response body nor // content type - if (contentType.isEmpty()) { + if (contentType.isEmpty() || contentLength == 0 + || response.statusCode().equals(HttpStatus.ACCEPTED)) { logger.trace("Message was successfully sent via POST for session {}", sessionRepresentation); // signal the caller that the message was successfully @@ -288,7 +341,7 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { logger.trace("Received response to POST for session {}", sessionRepresentation); // communicate to caller the message was delivered sink.success(); - return responseFlux(response); + return directResponseFlux(message, response); } else { logger.warn("Unknown media type {} returned for POST in session {}", contentType, @@ -298,19 +351,19 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } } else { - if (isNotFound(response)) { + if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) { return mcpSessionNotFoundError(sessionRepresentation); } - return extractError(response, sessionRepresentation); + return this.extractError(response, sessionRepresentation); } - }) + })) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) - .onErrorResume(t -> { + .onErrorComplete(t -> { // handle the error first this.handleException(t); // inform the caller of sendMessage sink.error(t); - return Flux.empty(); + return true; }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); @@ -338,13 +391,13 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; Exception toPropagate; try { - McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, - McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse jsonRpcResponse = jsonMapper.readValue(body, McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); - toPropagate = new McpError(jsonRpcError); + toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) + : new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse); } catch (IOException ex) { - toPropagate = new RuntimeException("Sending request failed", e); + toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e); logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); } @@ -353,9 +406,13 @@ private Flux extractError(ClientResponse response, Str // invalidate the session // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + if (!sessionRepresentation.equals(MISSING_SESSION_ID)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session " + + sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate)); } - return Mono.empty(); + return Mono.error(toPropagate); }).flux(); } @@ -382,18 +439,26 @@ private static boolean isEventStream(ClientResponse response) { } private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { - return transportSession.sessionId().orElse("[missing_session_id]"); + return transportSession.sessionId().orElse(MISSING_SESSION_ID); } - private Flux responseFlux(ClientResponse response) { + private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, + ClientResponse response) { return response.bodyToMono(String.class).>handle((responseMessage, s) -> { try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseMessage); - s.next(List.of(jsonRpcResponse)); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(responseMessage) ? responseMessage : "[empty]"); + s.complete(); + } + else { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(jsonMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } } catch (IOException e) { - s.error(e); + s.error(new McpTransportException(e)); } }).flatMapIterable(Function.identity()); } @@ -407,8 +472,8 @@ private Flux newEventStream(ClientResponse response, S } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } private Tuple2, Iterable> parse(ServerSentEvent event) { @@ -416,15 +481,16 @@ private Tuple2, Iterable> parse(Serve try { // We don't support batching ATM and probably won't since the next version // considers removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); } catch (IOException ioException) { - throw new McpError("Error parsing JSON-RPC message: " + event.data()); + throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException); } } else { - throw new McpError("Received unrecognized SSE event type: " + event.event()); + logger.debug("Received SSE event with type: {}", event); + return Tuples.of(Optional.empty(), List.of()); } } @@ -433,7 +499,7 @@ private Tuple2, Iterable> parse(Serve */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private WebClient.Builder webClientBuilder; @@ -443,19 +509,22 @@ public static class Builder { private boolean openConnectionOnStartup = false; + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + private Builder(WebClient.Builder webClientBuilder) { Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); this.webClientBuilder = webClientBuilder; } /** - * Configure the {@link ObjectMapper} to use. - * @param objectMapper instance to use + * Configure the {@link McpJsonMapper} to use. + * @param jsonMapper instance to use * @return the builder instance */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -508,16 +577,38 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { return this; } + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

+ * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + /** * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using * the current builder configuration. * @return a new instance of {@link WebClientStreamableHttpTransport} */ public WebClientStreamableHttpTransport build() { - ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - - return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, - openConnectionOnStartup); + return new WebClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + webClientBuilder, endpoint, resumableStreams, openConnectionOnStartup, supportedProtocolVersions); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 128cda4c3..91b89d6d2 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -1,18 +1,22 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,6 +66,8 @@ public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + /** * Event type for JSON-RPC messages received through the SSE connection. The server * sends messages with this event type to transmit JSON-RPC protocol data. @@ -94,10 +100,10 @@ public class WebFluxSseClientTransport implements McpClientTransport { private final WebClient webClient; /** - * ObjectMapper for serializing outbound messages and deserializing inbound messages. + * JSON mapper for serializing outbound messages and deserializing inbound messages. * Handles conversion between JSON-RPC messages and their string representation. */ - protected ObjectMapper objectMapper; + protected McpJsonMapper jsonMapper; /** * Subscription for the SSE connection handling inbound messages. Used for cleanup @@ -123,27 +129,16 @@ public class WebFluxSseClientTransport implements McpClientTransport { */ private String sseEndpoint; - /** - * Constructs a new SseClientTransport with the specified WebClient builder. Uses a - * default ObjectMapper instance for JSON processing. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @throws IllegalArgumentException if webClientBuilder is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { - this(webClientBuilder, new ObjectMapper()); - } - /** * Constructs a new SseClientTransport with the specified WebClient builder and * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @param objectMapper the ObjectMapper to use for JSON processing + * @param jsonMapper the ObjectMapper to use for JSON processing * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + this(webClientBuilder, jsonMapper, DEFAULT_SSE_ENDPOINT); } /** @@ -151,21 +146,25 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe * ObjectMapper. Initializes both inbound and outbound message processing pipelines. * @param webClientBuilder the WebClient.Builder to use for creating the WebClient * instance - * @param objectMapper the ObjectMapper to use for JSON processing + * @param jsonMapper the ObjectMapper to use for JSON processing * @param sseEndpoint the SSE endpoint URI to use for establishing the connection * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); this.sseEndpoint = sseEndpoint; } + @Override + public List protocolVersions() { + return List.of(MCP_PROTOCOL_VERSION); + } + /** * Establishes a connection to the MCP server using Server-Sent Events (SSE). This * method initiates the SSE connection and sets up the message processing pipeline. @@ -185,8 +184,6 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe * @param handler a function that processes incoming JSON-RPC messages and returns * responses * @return a Mono that completes when the connection is fully established - * @throws McpError if there's an error processing SSE events or if an unrecognized - * event type is received */ @Override public Mono connect(Function, Mono> handler) { @@ -203,12 +200,12 @@ public Mono connect(Function, Mono> h else { // TODO: clarify with the spec if multiple events can be // received - s.error(new McpError("Failed to handle SSE endpoint event")); + s.error(new RuntimeException("Failed to handle SSE endpoint event")); } } else if (MESSAGE_EVENT_TYPE.equals(event.event())) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, event.data()); s.next(message); } catch (IOException ioException) { @@ -216,7 +213,8 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) { } } else { - s.error(new McpError("Received unrecognized SSE event type: " + event.event())); + logger.debug("Received unrecognized SSE event type: {}", event); + s.complete(); } }).transform(handler)).subscribe(); @@ -245,10 +243,11 @@ public Mono sendMessage(JSONRPCMessage message) { return Mono.empty(); } try { - String jsonText = this.objectMapper.writeValueAsString(message); + String jsonText = this.jsonMapper.writeValueAsString(message); return webClient.post() .uri(messageEndpointUri) .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .bodyValue(jsonText) .retrieve() .toBodilessEntity() @@ -281,6 +280,7 @@ protected Flux> eventStream() {// @formatter:off .get() .uri(this.sseEndpoint) .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .retrieve() .bodyToFlux(SSE_TYPE) .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); @@ -337,13 +337,13 @@ public Mono closeGracefully() { // @formatter:off * type conversion capabilities to handle complex object structures. * @param the target type to convert the data into * @param data the source object to convert - * @param typeRef the TypeReference describing the target type + * @param typeRef the TypeRef describing the target type * @return the unmarshalled object of type T * @throws IllegalArgumentException if the conversion cannot be performed */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } /** @@ -365,7 +365,7 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; /** * Creates a new builder with the specified WebClient.Builder. @@ -388,13 +388,13 @@ public Builder sseEndpoint(String sseEndpoint) { } /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper + * Sets the JSON mapper for serialization/deserialization. + * @param jsonMapper the JsonMapper to use * @return this builder */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -403,7 +403,8 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public WebFluxSseClientTransport build() { - return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint); + return new WebFluxSseClientTransport(webClientBuilder, + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, sseEndpoint); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9aa..0c80c5b8b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,17 +1,26 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; +import java.time.Duration; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -26,6 +35,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.util.UriComponentsBuilder; /** * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using @@ -77,14 +87,18 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + private static final String MCP_PROTOCOL_VERSION = "2025-06-18"; + /** * Default SSE endpoint path as specified by the MCP transport specification. */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String SESSION_ID = "sessionId"; + public static final String DEFAULT_BASE_URL = ""; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; /** * Base URL for the message endpoint. This is used to construct the full URL for @@ -105,63 +119,67 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** * Flag indicating if the transport is shutting down. */ private volatile boolean isClosing = false; /** - * Constructs a new WebFlux SSE server transport provider instance with the default - * SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } + private KeepAliveScheduler keepAliveScheduler; /** * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. + * @param jsonMapper The ObjectMapper to use for JSON serialization/deserialization of + * MCP messages. Must not be null. * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @param keepAliveInterval The interval for sending keep-alive pings to clients. + * @param contextExtractor The context extractor to use for extracting MCP transport + * context from HTTP requests. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } @Override @@ -207,18 +225,10 @@ public Mono notifyClients(String method, Object params) { // FIXME: This javadoc makes claims about using isClosing flag but it's not // actually // doing that. + /** * Initiates a graceful shutdown of all the sessions. This method ensures all active * sessions are properly closed and cleaned up. - * - *

- * The shutdown process: - *

    - *
  • Marks the transport as closing to prevent new connections
  • - *
  • Closes each active session
  • - *
  • Removes closed sessions from the sessions map
  • - *
  • Times out after 5 seconds if shutdown takes too long
  • - *
* @return A Mono that completes when all sessions have been closed */ @Override @@ -226,7 +236,14 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) .flatMap(McpServerSession::closeGracefully) - .then(); + .then() + .doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -256,6 +273,8 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + McpTransportContext transportContext = this.contextExtractor.extract(request); + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -269,15 +288,28 @@ private Mono handleSseConnection(ServerRequest request) { // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder() - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) - .build()); + sink.next( + ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)).build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); sessions.remove(sessionId); }); - }), ServerSentEvent.class); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); } /** @@ -311,9 +343,11 @@ private Mono handleMessage(ServerRequest request) { .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); } + McpTransportContext transportContext = this.contextExtractor.extract(request); + return request.bodyToMono(String.class).flatMap(body -> { try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { logger.error("Error processing message: {}", error.getMessage()); // TODO: instead of signalling the error, just respond with 200 OK @@ -327,7 +361,7 @@ private Mono handleMessage(ServerRequest request) { logger.error("Failed to deserialize message: {}", e.getMessage()); return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); } - }); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } private class WebFluxMcpSessionTransport implements McpServerTransport { @@ -342,7 +376,7 @@ public WebFluxMcpSessionTransport(FluxSink> sink) { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromSupplier(() -> { try { - return objectMapper.writeValueAsString(message); + return jsonMapper.writeValueAsString(message); } catch (IOException e) { throw Exceptions.propagate(e); @@ -361,8 +395,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -389,7 +423,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String baseUrl = DEFAULT_BASE_URL; @@ -397,16 +431,21 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private Duration keepAliveInterval; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The McpJsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -447,6 +486,33 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the interval for sending keep-alive pings to clients. + * @param keepAliveInterval The keep-alive interval duration. If null, keep-alive + * is disabled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -454,10 +520,9 @@ public Builder sseEndpoint(String sseEndpoint) { * @throws IllegalStateException if required parameters are not set */ public WebFluxSseServerTransportProvider build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java new file mode 100644 index 000000000..400be341e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -0,0 +1,222 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebFlux based {@link McpStatelessServerTransport}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class); + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint handling two HTTP methods: + *

    + *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • + *
  • POST {messageEndpoint} - For handling client requests and notifications
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private Mono handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest).flatMap(jsonrpcResponse -> { + try { + String json = jsonMapper.writeValueAsString(jsonrpcResponse); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).bodyValue(json); + } + catch (IOException e) { + logger.error("Failed to serialize response: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError("Failed to serialize response")); + } + }); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .then(ServerResponse.accepted().build()); + } + else { + return ServerResponse.badRequest() + .bodyValue(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The JsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStatelessServerTransport} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java new file mode 100644 index 000000000..deebfc616 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -0,0 +1,495 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); + + public static final String MESSAGE_EVENT_TYPE = "message"; + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final boolean disallowDelete; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor, boolean disallowDelete, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.disallowDelete = disallowDelete; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.isClosing = true; + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpStreamableServerSession::closeGracefully) + .then(); + }).then().doOnSuccess(v -> { + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint with three methods: + *

    + *
  • GET {messageEndpoint} - For the client listening SSE stream
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
  • DELETE {messageEndpoint} - For removing sessions
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Opens the listening SSE streams for clients. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleGet(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return Mono.defer(() -> { + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().build(); + } + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( + sink); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); + + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Handles incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A Mono with the response appropriate to a particular Streamable HTTP flow. + */ + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + var typeReference = new TypeRef() { + }; + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + typeReference); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + sessions.put(init.session().getId(), init.session()); + return init.initResult().map(initializeResult -> { + McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); + try { + return this.jsonMapper.writeValueAsString(jsonrpcResponse); + } + catch (IOException e) { + logger.warn("Failed to serialize initResponse", e); + throw Exceptions.propagate(e); + } + }) + .flatMap(initResult -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .bodyValue(initResult)); + } + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.responseStream(jsonrpcRequest, st); + Disposable streamSubscription = stream.onErrorComplete(err -> { + sink.error(err); + return true; + }).contextWrite(sink.contextView()).subscribe(); + sink.onCancel(streamSubscription); + // TODO Clarify why the outer context is not present in the + // Flux.create sink? + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), + ServerSentEvent.class); + } + else { + return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }) + .switchIfEmpty(ServerResponse.badRequest().build()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private Mono handleDelete(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + return Mono.defer(() -> { + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + return session.delete().then(ServerResponse.ok().build()); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final FluxSink> sink; + + public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return this.sendMessage(message, null); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromSupplier(() -> { + try { + return jsonMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .id(messageId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxStreamableServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private boolean disallowDelete; + + private Duration keepAliveInterval; + + private Builder() { + // used by a static method + } + + /** + * Sets the {@link McpJsonMapper} to use for JSON serialization/deserialization of + * MCP messages. + * @param jsonMapper The {@link McpJsonMapper} instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets whether the session removal capability is disabled. + * @param disallowDelete if {@code true}, the DELETE endpoint will not be + * supported and sessions won't be deleted. + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the keep-alive interval for the server transport. + * @param keepAliveInterval The interval for sending keep-alive messages. If null, + * no keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebFluxStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStreamableServerTransportProvider build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebFluxStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor, + disallowDelete, keepAliveInterval); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2f85654e8..eb8abb90c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -1,54 +1,39 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol; import java.time.Duration; -import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; +import java.util.stream.Stream; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.TestUtil; -import io.modelcontextprotocol.server.McpSyncServerExchange; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.*; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Mono; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; -import reactor.test.StepVerifier; +import org.springframework.web.reactive.function.server.ServerRequest; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; -class WebFluxSseIntegrationTests { +@Timeout(15) +class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -60,969 +45,62 @@ class WebFluxSseIntegrationTests { private WebFluxSseServerTransportProvider mcpServerTransportProvider; - ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); - @BeforeEach - public void before() { + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } - this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); + @Override + protected void prepareClients(int port, String mcpEndpoint) { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); - clientBuilders.put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); clientBuilders.put("webflux", McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), - (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) - .thenReturn(mock(CallToolResult.class))); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build();) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(craeteMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(4)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .build(); - - return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(1)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - } - - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithoutElicitationCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutFail(String clientType) { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); + .build()) + .requestTimeout(Duration.ofHours(10))); - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - // Create client without roots capability - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotificationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleHandlers(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 3; - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request) -> { - - // Create and send notifications with different levels - - //@formatter:off - return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); - //@formatter:on - }); + @BeforeEach + public void before() { - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) .build(); - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - } - mcpServer.close(); + prepareClients(PORT, null); } - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "this is code review prompt", - List.of(new PromptArgument("language", "string", false))), - (mcpSyncServerExchange, getPromptRequest) -> null)) - .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "code_review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); } - - mcpServer.close(); } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java new file mode 100644 index 000000000..96a786a9e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +@Timeout(15) +class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStatelessServerTransport mcpStreamableServerTransport; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + clientBuilders + .put("webflux", McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpStreamableServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpStreamableServerTransport); + } + + @BeforeEach + public void before() { + this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..5d2bfda68 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.utils.McpTestRequestRecordingExchangeFilterFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +class WebFluxStreamableHttpVersionNegotiationIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private DisposableServer httpServer; + + private final McpTestRequestRecordingExchangeFilterFunction recordingFilterFunction = new McpTestRequestRecordingExchangeFilterFunction(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> new McpSchema.CallToolResult( + exchange.transportContext().get("protocol-version").toString(), null); + + private final WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(req -> McpTransportContext + .create(Map.of("protocol-version", req.headers().firstHeader("MCP-protocol-version")))) + .build(); + + private final McpSyncServer mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler)) + .build(); + + @BeforeEach + void setUp() { + RouterFunction filteredRouter = mcpStreamableServerTransportProvider.getRouterFunction() + .filter(recordingFilterFunction); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(filteredRouter); + + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + if (mcpServer != null) { + mcpServer.close(); + } + } + + @Test + void usesLatestVersion() { + var client = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .build()) + .requestTimeout(Duration.ofHours(10)) + .build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + var transport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).requestTimeout(Duration.ofHours(10)).build(); + + client.initialize(); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = recordingFilterFunction.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls) + .filteredOn(c -> !c.body().contains("\"method\":\"initialize\"") && c.method().equals(HttpMethod.POST)) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingExchangeFilterFunction.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_06_18)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_06_18); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java new file mode 100644 index 000000000..5ab651931 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +@Timeout(15) +class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpStreamableServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpStreamableServerTransportProvider); + } + + @BeforeEach + public void before() { + + this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); + + HttpHandler httpHandler = RouterFunctions + .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java index 80fc671e2..191f10376 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -1,6 +1,9 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index f824193fd..cf4458506 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -1,22 +1,27 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -26,15 +31,15 @@ protected McpClientTransport createMcpTransport() { return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java index 9ecd8a7d1..f47ba5277 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -1,22 +1,27 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -26,15 +31,15 @@ protected McpClientTransport createMcpTransport() { return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0edf4cd54..72c0168d5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -6,6 +6,8 @@ import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; @@ -24,28 +26,27 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); + .waitingFor(Wait.forHttp("/").forStatusCode(404).forPort(3001)); @Override protected McpClientTransport createMcpTransport() { return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 9b0959a35..b483029e0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -8,10 +8,11 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; - import org.springframework.web.reactive.function.client.WebClient; /** @@ -24,10 +25,9 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -37,15 +37,15 @@ protected McpClientTransport createMcpTransport() { return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..214fa489b --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,403 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for error handling in WebClientStreamableHttpTransport. Addresses concurrency + * issues with proper Reactor patterns. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class WebClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + // Initialize latches for proper request synchronization + CountDownLatch firstRequestLatch; + + CountDownLatch secondRequestLatch; + + CountDownLatch getRequestLatch; + + @BeforeEach + void startServer() throws IOException { + + // Initialize latches for proper synchronization + firstRequestLatch = new CountDownLatch(1); + secondRequestLatch = new CountDownLatch(1); + getRequestLatch = new CountDownLatch(1); + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("GET".equals(method)) { + // This is the SSE connection attempt after session establishment + getRequestLatch.countDown(); + // Return 405 Method Not Allowed to indicate SSE not supported + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Track which request this is + if (firstRequestLatch.getCount() > 0) { + // // First request - should have no session ID + firstRequestLatch.countDown(); + } + else if (secondRequestLatch.getCount() > 0) { + // Second request - should have session ID + secondRequestLatch.countDown(); + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Don't include session ID in 404 and 400 responses - the implementation + // checks if the transport has a session stored locally + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null && status == 200) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + else { + exchange.sendResponseHeaders(status, 0); + } + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test404WithSessionId() throws InterruptedException { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + assertThat(firstRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + assertThat(getRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Now return 404 for next request + serverResponseStatus.set(404); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + assertThat(secondRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-123"); + + // Verify exception handler was called with SessionNotFoundException using + // timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test400WithSessionId() throws InterruptedException { + + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + boolean firstCompleted = firstRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(firstCompleted).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + boolean getCompleted = getRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(getCompleted).isTrue(); + + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + boolean secondCompleted = secondRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(secondCompleted).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-456"); + + // Verify exception handler was called with timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test session recovery after SessionNotFoundException Fixed version using reactive + * patterns and proper synchronization + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Use Mono.defer to ensure proper sequencing + Mono establishSession = transport.sendMessage(testMessage).then(Mono.defer(() -> { + // Simulate session loss - return 404 + serverResponseStatus.set(404); + return transport.sendMessage(testMessage).onErrorResume(McpTransportSessionNotFoundException.class, e -> { + // Expected error, continue with recovery + return Mono.empty(); + }); + })).then(Mono.defer(() -> { + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + return transport.sendMessage(testMessage); + })).then(Mono.defer(() -> { + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + return transport.sendMessage(testMessage); + })).doOnSuccess(v -> { + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + }); + + StepVerifier.create(establishSession).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors Fixed + * version with proper async handling + */ + @Test + void testReconnectErrorHandling() throws InterruptedException { + // Initialize latch for SSE connection + CountDownLatch sseConnectionLatch = new CountDownLatch(1); + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + sseConnectionLatch.countDown(); + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Wait for SSE connection to be established + boolean connected = sseConnectionLatch.await(5, TimeUnit.SECONDS); + assertThat(connected).isTrue(); + + // Send message to establish session + var testMessage = createTestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Clean up + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..34e422be4 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import org.springframework.web.reactive.function.client.WebClient; + +class WebClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + static WebClient.Builder builder; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + builder = WebClient.builder().baseUrl(host); + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void testCloseUninitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = WebClientStreamableHttpTransport.builder(builder).build(); + + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 42b91d14e..a29c9d69c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -6,13 +6,18 @@ import java.time.Duration; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -26,6 +31,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -41,8 +47,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -51,8 +57,6 @@ class WebFluxSseClientTransportTests { private WebClient.Builder webClientBuilder; - private ObjectMapper objectMapper; - // Test class to access protected methods static class TestSseClientTransport extends WebFluxSseClientTransport { @@ -60,8 +64,8 @@ static class TestSseClientTransport extends WebFluxSseClientTransport { private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - super(webClientBuilder, objectMapper); + public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { + super(webClientBuilder, jsonMapper); } @Override @@ -77,6 +81,11 @@ public int getInboundMessageCount() { return inboundMessageCount.get(); } + public void simulateSseComment(String comment) { + events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); + inboundMessageCount.incrementAndGet(); + } + public void simulateEndpointEvent(String jsonMessage) { events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); inboundMessageCount.incrementAndGet(); @@ -89,18 +98,22 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } + @AfterAll + static void cleanup() { + container.stop(); + } + @BeforeEach void setUp() { - startContainer(); webClientBuilder = WebClient.builder().baseUrl(host); - objectMapper = new ObjectMapper(); - transport = new TestSseClientTransport(webClientBuilder, objectMapper); + transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER); transport.connect(Function.identity()).block(); } @@ -109,11 +122,6 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); - } - - void cleanup() { - container.stop(); } @Test @@ -123,12 +131,13 @@ void testEndpointEventHandling() { @Test void constructorValidation() { - assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> new WebFluxSseClientTransport(null, JSON_MAPPER)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("WebClient.Builder must not be null"); assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ObjectMapper must not be null"); + .hasMessageContaining("jsonMapper must not be null"); } @Test @@ -140,7 +149,7 @@ void testBuilderPattern() { // Test builder with custom ObjectMapper ObjectMapper customMapper = new ObjectMapper(); WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) + .jsonMapper(new JacksonMcpJsonMapper(customMapper)) .build(); assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); @@ -152,12 +161,32 @@ void testBuilderPattern() { // Test builder with all custom parameters WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) .sseEndpoint("/custom-sse") .build(); assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); } + @Test + void testCommentSseMessage() { + // If the line starts with a character (:) are comment lins and should be ingored + // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + + CopyOnWriteArrayList droppedErrors = new CopyOnWriteArrayList<>(); + reactor.core.publisher.Hooks.onErrorDropped(droppedErrors::add); + + try { + // Simulate receiving the SSE comment line + transport.simulateSseComment("sse comment"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + assertThat(droppedErrors).hasSize(0); + } + finally { + reactor.core.publisher.Hooks.resetOnErrorDropped(); + } + } + @Test void testMessageProcessing() { // Create a test message diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..3db0bbd3a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers using Spring WebFlux infrastructure. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication for asynchronous client and server implementations. It tests various + * combinations of client types and server transport mechanisms (stateless, streamable, + * SSE) to ensure proper context handling across different configurations. + * + *

Context Propagation Flow

+ *
    + *
  1. Client sets a value in its transport context via thread-local Reactor context
  2. + *
  3. Client-side context provider extracts the value and adds it as an HTTP header to + * the request
  4. + *
  5. Server-side context extractor reads the header from the incoming request
  6. + *
  7. Server handler receives the extracted context and returns the value as the tool + * call result
  8. + *
  9. Test verifies the round-trip context propagation was successful
  10. + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HEADER_NAME = "x-test"; + + // Async client context provider + ExchangeFilterFunction asyncClientContextProvider = (request, next) -> Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = transportContext.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }); + + // Tools + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + // Server context extractor + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + // Server transports + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + // Async clients + private final McpAsyncClient asyncStreamableClient = McpClient + .async(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopHttpServer(); + } + + @Test + void asyncClientStatelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..94e16e73e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP client and + * server using synchronous operations in a Spring WebFlux environment. + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different WebFlux-based MCP transport implementations + * + *

+ * The test scenario follows these steps: + *

    + *
  1. The client stores a value in a thread-local variable
  2. + *
  3. The client's transport context provider reads this value and includes it in the MCP + * context
  4. + *
  5. A WebClient filter extracts the context value and adds it as an HTTP header + * (x-test)
  6. + *
  7. The server's {@link McpTransportContextExtractor} reads the header from the + * request
  8. + *
  9. The server returns the header value as the tool call result, validating the + * round-trip
  10. + *
+ * + *

+ * This test demonstrates how custom context can be propagated through HTTP headers in a + * reactive WebFlux environment, enabling features like authentication tokens, correlation + * IDs, or other metadata to flow between MCP client and server. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @since 1.0.0 + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see WebFluxStatelessServerTransport + * @see WebFluxStreamableServerTransportProvider + * @see WebFluxSseServerTransportProvider + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, request) -> { + return new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + }; + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient.sync(WebFluxSseClientTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()).transportContextProvider(clientContextProvider).build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopHttpServer(); + } + + @Test + void statelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index cc33e7b94..fe0314687 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -29,10 +28,8 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { - var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) + private McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT) .build(); HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); @@ -41,6 +38,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 2fc104538..67ef90bdf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -32,10 +31,12 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private WebFluxSseServerTransportProvider transportProvider; @Override - protected McpServerTransportProvider createMcpTransportProvider() { - transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider.Builder().messageEndpoint(MESSAGE_ENDPOINT).build(); return transportProvider; } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java new file mode 100644 index 000000000..9b5a80f16 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java new file mode 100644 index 000000000..6a47ba3ae --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java index 0ab72a99f..dfb004e9b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java new file mode 100644 index 000000000..67347573c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.utils; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.createDefault(); + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java new file mode 100644 index 000000000..55129d481 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/utils/McpTestRequestRecordingExchangeFilterFunction.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.utils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.web.reactive.function.server.HandlerFilterFunction; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Simple {@link HandlerFilterFunction} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingExchangeFilterFunction implements HandlerFilterFunction { + + private final List calls = new ArrayList<>(); + + @Override + public Mono filter(ServerRequest request, HandlerFunction next) { + Map headers = request.headers() + .asHttpHeaders() + .keySet() + .stream() + .collect(Collectors.toMap(String::toLowerCase, k -> String.join(",", request.headers().header(k)))); + + var cr = request.bodyToMono(String.class).defaultIfEmpty("").map(body -> { + this.calls.add(new Call(request.method(), headers, body)); + return ServerRequest.from(request).body(body).build(); + }); + + return cr.flatMap(next::handle); + + } + + public List getCalls() { + return List.copyOf(calls); + } + + public record Call(HttpMethod method, Map headers, String body) { + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 48d1c3465..df18b1b8b 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc jar - Spring Web MVC implementation of the Java MCP SSE transport - + Spring Web MVC transports + Web MVC implementation for the SSE and Streamable Http Server transports https://github.com/modelcontextprotocol/java-sdk @@ -22,23 +22,36 @@ - + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT + + + io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT + + + + org.springframework + spring-webmvc + ${springframework.version} io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT test - org.springframework - spring-webmvc - ${springframework.version} + io.modelcontextprotocol.sdk + mcp-spring-webflux + 0.18.0-SNAPSHOT + test @@ -128,7 +141,14 @@ test + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + test + + - \ No newline at end of file + diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa0..6c35de56d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,18 +6,22 @@ import java.io.IOException; import java.time.Duration; -import java.util.Map; -import java.util.UUID; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -29,6 +33,7 @@ import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse.SseBuilder; +import org.springframework.web.util.UriComponentsBuilder; /** * Server-side implementation of the Model Context Protocol (MCP) transport layer using @@ -80,12 +85,14 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String SESSION_ID = "sessionId"; + /** * Default SSE endpoint path as specified by the MCP transport specification. */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String messageEndpoint; @@ -102,42 +109,18 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** * Flag indicating if the transport is shutting down. */ private volatile boolean isClosing = false; - /** - * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @throws IllegalArgumentException if any parameter is null - */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); - } + private KeepAliveScheduler keepAliveScheduler; /** * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization * of messages. * @param baseUrl The base URL for the message endpoint, used to construct the full * endpoint URL for clients. @@ -145,23 +128,45 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param keepAliveInterval The interval for sending keep-alive messages to clients. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. * @throws IllegalArgumentException if any parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } @Override @@ -209,10 +214,13 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()).doFirst(() -> { this.isClosing = true; logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }) - .flatMap(McpServerSession::closeGracefully) - .then() - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -247,41 +255,46 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - // Send initial endpoint event - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - sessions.remove(sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - sessions.remove(sessionId); - }); - - WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); - this.sessions.put(sessionId, session); + return ServerResponse.sse(sseBuilder -> { + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + this.sessions.put(sessionId, session); - try { - sseBuilder.id(sessionId) - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } + try { + sseBuilder.event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId)); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + this.sessions.remove(sessionId); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + return UriComponentsBuilder.fromUriString(this.baseUrl) + .path(this.messageEndpoint) + .queryParam(SESSION_ID, sessionId) + .build() + .toUriString(); } /** @@ -300,11 +313,11 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (request.param("sessionId").isEmpty()) { + if (request.param(SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } - String sessionId = request.param("sessionId").get(); + String sessionId = request.param(SESSION_ID).get(); McpServerSession session = sessions.get(sessionId); if (session == null) { @@ -312,11 +325,16 @@ private ServerResponse handleMessage(ServerRequest request) { } try { + final McpTransportContext transportContext = this.contextExtractor.extract(request); + String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); // Process the message through the session's handle method - session.handle(message).block(); // Block for WebMVC compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block + // for + // WebMVC + // compatibility return ServerResponse.ok().build(); } @@ -336,19 +354,20 @@ private ServerResponse handleMessage(ServerRequest request) { */ private class WebMvcMcpSessionTransport implements McpServerTransport { - private final String sessionId; - private final SseBuilder sseBuilder; /** - * Creates a new session transport with the specified ID and SSE builder. - * @param sessionId The unique identifier for this session + * Lock to ensure thread-safe access to the SSE builder when sending messages. + * This prevents concurrent modifications that could lead to corrupted SSE events. + */ + private final ReentrantLock sseBuilderLock = new ReentrantLock(); + + /** + * Creates a new session transport with the specified SSE builder. * @param sseBuilder The SSE builder for sending server events to the client */ - WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { - this.sessionId = sessionId; + WebMvcMcpSessionTransport(SseBuilder sseBuilder) { this.sseBuilder = sseBuilder; - logger.debug("Session transport {} initialized with SSE builder", sessionId); } /** @@ -359,28 +378,31 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { + sseBuilderLock.lock(); try { - String jsonText = objectMapper.writeValueAsString(message); - sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); - logger.debug("Message sent to session {}", sessionId); + String jsonText = jsonMapper.writeValueAsString(message); + sseBuilder.event(MESSAGE_EVENT_TYPE).data(jsonText); } catch (Exception e) { - logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + logger.error("Failed to send message: {}", e.getMessage()); sseBuilder.error(e); } + finally { + sseBuilderLock.unlock(); + } }); } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured McpJsonMapper. * @param data The source data object to convert * @param typeRef The target type reference - * @return The converted object of type T * @param The target type + * @return The converted object of type T */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -390,13 +412,15 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { @Override public Mono closeGracefully() { return Mono.fromRunnable(() -> { - logger.debug("Closing session transport: {}", sessionId); + sseBuilderLock.lock(); try { sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); } catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); + } + finally { + sseBuilderLock.unlock(); } }); } @@ -406,13 +430,137 @@ public Mono closeGracefully() { */ @Override public void close() { + sseBuilderLock.lock(); try { sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); } catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + logger.warn("Failed to complete SSE builder: {}", e.getMessage()); + } + finally { + sseBuilderLock.unlock(); + } + } + + } + + /** + * Creates a new Builder instance for configuring and creating instances of + * WebMvcSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of WebMvcSseServerTransportProvider. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String baseUrl = ""; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private Duration keepAliveInterval; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param jsonMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

+ * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of WebMvcSseServerTransportProvider with the configured + * settings. + * @return A new WebMvcSseServerTransportProvider instance + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set + */ + public WebMvcSseServerTransportProvider build() { + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); } + return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java new file mode 100644 index 000000000..4223084ff --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -0,0 +1,240 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebMVC based {@link McpStatelessServerTransport}. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport} + * + * @author Christian Tzolov + */ +public class WebMvcStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); + + private final McpJsonMapper jsonMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebMVC router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint handling two HTTP methods: + *

    + *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • + *
  • POST {messageEndpoint} - For handling client requests and notifications
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private ServerResponse handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private ServerResponse handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).body(jsonrpcResponse); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + return ServerResponse.badRequest() + .body(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcStatelessServerTransport with custom settings. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "ObjectMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStatelessServerTransport} with the + * configured settings. + * @return A new WebMvcStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStatelessServerTransport build() { + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + return new WebMvcStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, + mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java new file mode 100644 index 000000000..f2a58d4d8 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -0,0 +1,690 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import io.modelcontextprotocol.json.McpJsonMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +import io.modelcontextprotocol.json.TypeRef; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This + * implementation provides a bridge between synchronous WebMVC operations and reactive + * programming patterns to maintain compatibility with the reactive transport interface. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider} + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see RouterFunction + */ +public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default base URL for the message endpoint. + */ + public static final String DEFAULT_BASE_URL = ""; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final McpJsonMapper jsonMapper; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new WebMvcStreamableServerTransportProvider instance. + * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @throws IllegalArgumentException if any parameter is null + */ + private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); + + this.jsonMapper = jsonMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles three endpoints: + *

    + *
  • GET [mcpEndpoint] - For establishing SSE connections and message replay
  • + *
  • POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
  • + *
  • DELETE [mcpEndpoint] - For session deletion (if enabled)
  • + *
+ * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Setup the listening SSE connections and message replay. + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response + */ + private ServerResponse handleGet(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + logger.debug("Handling GET request for session: {}", sessionId); + + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + // Check if this is a replay request + if (!request.headers().header(HttpHeaders.LAST_EVENT_ID).isEmpty()) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + sseBuilder.error(e); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + sseBuilder.error(e); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + }); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handlePost(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) + || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { + return ServerResponse.badRequest() + .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + return ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, + null)); + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + // Handle other messages that require a session + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .body(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("Request response stream completed for session: {}", sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("Request response stream timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The incoming server request + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handleDelete(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); + + if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + return ServerResponse.ok().build(); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class + * handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying SSE builder to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + private final ReentrantLock lock = new ReentrantLock(); + + private volatile boolean closed = false; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Streamable session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = jsonMapper.writeValueAsString(message); + this.sseBuilder.id(messageId != null ? messageId : this.sessionId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + try { + this.sseBuilder.error(e); + } + catch (Exception errorException) { + logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId, + errorException.getMessage()); + } + } + finally { + this.lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured McpJsonMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + WebMvcStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + this.sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + finally { + this.lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}. + */ + public static class Builder { + + private McpJsonMapper jsonMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private Duration keepAliveInterval; + + /** + * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param jsonMapper The McpJsonMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if jsonMapper is null + */ + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); + this.jsonMapper = jsonMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be created to periodically check and send keep-alive messages to clients. + * @param keepAliveInterval The interval duration for keep-alive messages, or null + * to disable keep-alive + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebMvcStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStreamableServerTransportProvider build() { + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + return new WebMvcStreamableServerTransportProvider( + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java new file mode 100644 index 000000000..cc9945436 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java @@ -0,0 +1,297 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * servers using Spring WebMVC transport implementations. + * + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different MCP transport mechanisms in a Spring WebMVC environment. It demonstrates how + * contextual information can be passed from client to server through HTTP headers and + * properly extracted and utilized on the server side. + * + *

Transport Types Tested

+ *
    + *
  • Stateless: Tests context propagation with + * {@link WebMvcStatelessServerTransport} where each request is independent
  • + *
  • Streamable HTTP: Tests context propagation with + * {@link WebMvcStreamableServerTransportProvider} supporting stateful server + * sessions
  • + *
  • Server-Sent Events (SSE): Tests context propagation with + * {@link WebMvcSseServerTransportProvider} for long-lived connections
  • + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class McpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private TomcatServer tomcatServer; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private static final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private static final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private static McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + String headerValue = r.servletRequest().getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private static final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(TestStatelessConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void streamableServer() { + + startTomcat(TestStreamableHttpConfig.class); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + @Test + void sseServer() { + startTomcat(TestSseConfig.class); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + } + + private void startTomcat(Class componentClass) { + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, componentClass); + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcatServer != null && tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestStatelessConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestStreamableHttpConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport() { + + return WebMvcStreamableServerTransportProvider.builder().contextExtractor(serverContextExtractor).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpStreamableServer(WebMvcStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestSseConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport() { + + return WebMvcSseServerTransportProvider.builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpSseServer(WebMvcSseServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java index ccf9e2d77..8625b6a70 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -3,10 +3,6 @@ */ package io.modelcontextprotocol.server; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.ServerSocket; - import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java new file mode 100644 index 000000000..36aaa27fb --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java new file mode 100644 index 000000000..2f75551eb --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder().mcpEndpoint(MCP_ENDPOINT).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 6a6ad17e9..ccf3170c9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.apache.catalina.Context; @@ -37,7 +36,10 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); } @Bean @@ -49,8 +51,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -90,6 +91,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 1b5218cc5..d8d26af48 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -3,7 +3,6 @@ */ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; @@ -91,8 +90,14 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, - WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + // return new WebMvcSseServerTransportProvider(new ObjectMapper(), + // CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + // WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 3f3f7be62..045f9b3dd 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -3,50 +3,39 @@ */ package io.modelcontextprotocol.server; +import static org.assertj.core.api.Assertions.assertThat; + import java.time.Duration; -import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; +import java.util.stream.Stream; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import reactor.core.scheduler.Schedulers; -class WebMvcSseIntegrationTests { +@Timeout(15) +class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -54,7 +43,24 @@ class WebMvcSseIntegrationTests { private WebMvcSseServerTransportProvider mcpServerTransportProvider; - McpClient.SyncSpec clientBuilder; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()) + .requestTimeout(Duration.ofHours(10))); + } @Configuration @EnableWebMvc @@ -62,7 +68,10 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint(MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); } @Bean @@ -87,7 +96,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + prepareClients(PORT, MESSAGE_ENDPOINT); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -96,9 +105,11 @@ public void before() { @AfterEach public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); if (mcpServerTransportProvider != null) { mcpServerTransportProvider.closeGracefully().block(); } + Schedulers.shutdownNow(); if (tomcatServer.appContext() != null) { tomcatServer.appContext().close(); } @@ -113,753 +124,14 @@ public void after() { } } - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); - - //@formatter:off - var server = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder - .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) {//@formatter:on - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - //@formatter:off - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) {//@formatter:on - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(4)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try ( - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - @Test - void testInitialize() { - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 1964703c1..66d6d3ae9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; @@ -36,7 +35,7 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); } @Bean @@ -49,7 +48,11 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; @Override - protected WebMvcSseServerTransportProvider createMcpTransportProvider() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java new file mode 100644 index 000000000..8c7b0a85e --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.stream.Stream; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.AbstractStatelessIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import reactor.core.scheduler.Schedulers; + +@Timeout(15) +class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStatelessServerTransport mcpServerTransport; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder().messageEndpoint(MESSAGE_ENDPOINT).build(); + + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport statelessServerTransport) { + return statelessServerTransport.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransport); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + prepareClients(PORT, MESSAGE_ENDPOINT); + + // Get the transport from Spring context + this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); + + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (this.mcpServerTransport != null) { + this.mcpServerTransport.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java new file mode 100644 index 000000000..cb7b4a2a0 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import reactor.core.scheduler.Schedulers; + +@Timeout(15) +class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; + + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext + .create(Map.of("important", "value")); + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .mcpEndpoint(MESSAGE_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .build())); + + // Get the transport from Spring context + this.mcpServerTransportProvider = tomcatServer.appContext() + .getBean(WebMvcStreamableServerTransportProvider.class); + + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java new file mode 100644 index 000000000..1074e8a35 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for WebMvcSseServerTransportProvider + * + * @author lance + */ +class WebMvcSseServerTransportProviderTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_CONTEXT_PATH = ""; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(transport); + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @Test + void validBaseUrl() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + assertThat(client.initialize()).isNotNull(); + } + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return WebMvcSseServerTransportProvider.builder() + .baseUrl("http://localhost:" + PORT + "/") + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .jsonMapper(McpJsonMapper.getDefault()) + .contextExtractor(req -> McpTransportContext.EMPTY) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index bc1140bb5..d4ccbc173 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -9,16 +9,16 @@ - + - + - + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f24d9fab2..7fc22e5d2 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT @@ -91,6 +91,13 @@ ${logback.version} + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + + + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..270bc4308 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1760 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Utils; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + finally { + server.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var transportContextIsNull = new AtomicBoolean(false); + var transportContextIsEmpty = new AtomicBoolean(false); + var responseBodyIsNullOrBlank = new AtomicBoolean(false); + + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + transportContextIsNull.set(transportContext == null); + transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); + String ctxValue = (String) transportContext.get("important"); + + try { + String responseBody = "TOOL RESPONSE"; + responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); + } + catch (Exception e) { + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(transportContextIsNull.get()).isFalse(); + assertThat(transportContextIsEmpty.get()).isFalse(); + assertThat(responseBodyIsNullOrBlank.get()).isFalse(); + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @MethodSource("clientsForTesting") + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java new file mode 100644 index 000000000..240732ebe --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -0,0 +1,659 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + +public abstract class AbstractStatelessIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected StatelessAsyncSpecification prepareAsyncServerBuilder(); + + abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + finally { + server.closeGracefully().block(); + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((ctx, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + finally { + mcpServer.closeGracefully().block(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpStatelessSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((context, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) + .callHandler((ctx, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + // Remove a tool + mcpServer.removeTool("tool1"); + + // Add a new tool + McpStatelessServerFeatures.SyncToolSpecification tool2 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that throws an exception to simulate an error + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + var tool = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + var toolSpec = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index 5484a63c2..cd8458311 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -9,8 +9,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -93,8 +93,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonMapper.getDefault().convertValue(data, typeRef); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 6748eb75c..338eaf931 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import eu.rekawek.toxiproxy.Proxy; @@ -9,6 +10,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,10 +47,9 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { static Network network = Network.newNetwork(); static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) .withNetworkAliases("everything-server") @@ -79,7 +80,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } - private static void disconnect() { + static void disconnect() { long start = System.nanoTime(); try { // disconnect @@ -96,7 +97,7 @@ private static void disconnect() { } } - private static void reconnect() { + static void reconnect() { long start = System.nanoTime(); try { proxy.toxics().get("RESET_UPSTREAM").remove(); @@ -110,7 +111,7 @@ private static void reconnect() { } } - private static void restartMcpServer() { + static void restartMcpServer() { container.stop(); container.start(); } @@ -133,10 +134,13 @@ McpAsyncClient client(McpClientTransport transport, Function client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -216,9 +220,10 @@ void testSessionClose() { // In case of Streamable HTTP this call should issue a HTTP DELETE request // invalidating the session StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the - // request without any broken connections. - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 4f7d5678b..bee8f4f16 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -12,16 +13,16 @@ import java.time.Duration; import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -49,6 +50,7 @@ import io.modelcontextprotocol.spec.McpTransport; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; /** @@ -64,18 +66,12 @@ public abstract class AbstractMcpAsyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); + return Duration.ofSeconds(20); } McpAsyncClient client(McpClientTransport transport) { @@ -114,16 +110,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, String action) { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -182,6 +168,25 @@ void testListAllTools() { }); } + @Test + void testListAllToolsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testPingWithoutInitialization() { verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); @@ -333,6 +338,21 @@ void testListAllResources() { }); } + @Test + void testListAllResourcesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) + .consumeNextWith(result -> { + assertThat(result.resources()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy( + () -> result.resources().add(Resource.builder().uri("test://uri").name("test").build())) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testMcpAsyncClientState() { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -384,6 +404,20 @@ void testListAllPrompts() { }); } + @Test + void testListAllPromptsReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) + .consumeNextWith(result -> { + assertThat(result.prompts()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); @@ -440,7 +474,8 @@ void testAddRoot() { void testAddRootWithNullValue() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) .verify(); }); } @@ -459,7 +494,7 @@ void testRemoveRoot() { void testRemoveNonExistentRoot() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) .verify(); }); @@ -467,57 +502,64 @@ void testRemoveNonExistentRoot() { @Test void testReadResource() { + AtomicInteger resourceCount = new AtomicInteger(); withClient(createMcpTransport(), client -> { Flux resources = client.initialize() .then(client.listResources(null)) - .flatMapMany(r -> Flux.fromIterable(r.resources())) + .flatMapMany(r -> { + List l = r.resources(); + resourceCount.set(l.size()); + return Flux.fromIterable(l); + }) .flatMap(r -> client.readResource(r)); - StepVerifier.create(resources).recordWith(ArrayList::new).consumeRecordedWith(readResourceResults -> { - - for (ReadResourceResult result : readResourceResults) { - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull().isNotEmpty(); - - // Validate each content item - for (ResourceContents content : result.contents()) { - assertThat(content).isNotNull(); - assertThat(content.uri()).isNotNull().isNotEmpty(); - assertThat(content.mimeType()).isNotNull().isNotEmpty(); - - // Validate content based on its type with more comprehensive - // checks - switch (content.mimeType()) { - case "text/plain" -> { - TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, - content); - assertThat(textContent.text()).isNotNull().isNotEmpty(); - assertThat(textContent.uri()).isNotEmpty(); - } - case "application/octet-stream" -> { - BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, - content); - assertThat(blobContent.blob()).isNotNull().isNotEmpty(); - assertThat(blobContent.uri()).isNotNull().isNotEmpty(); - // Validate base64 encoding format - assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); - } - default -> { - - // Still validate basic properties - if (content instanceof TextResourceContents textContent) { - assertThat(textContent.text()).isNotNull(); + StepVerifier.create(resources) + .recordWith(ArrayList::new) + .thenConsumeWhile(res -> true) + .consumeRecordedWith(readResourceResults -> { + assertThat(readResourceResults.size()).isEqualTo(resourceCount.get()); + for (ReadResourceResult result : readResourceResults) { + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, + content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + assertThat(textContent.uri()).isNotEmpty(); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, + content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + assertThat(blobContent.uri()).isNotNull().isNotEmpty(); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); } - else if (content instanceof BlobResourceContents blobContent) { - assertThat(blobContent.blob()).isNotNull(); + default -> { + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } } } } } - } - }) - .expectNextCount(10) // Expect 10 elements + }) .verifyComplete(); }); } @@ -553,6 +595,21 @@ void testListAllResourceTemplates() { }); } + @Test + void testListAllResourceTemplatesReturnsImmutableList() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result.resourceTemplates()).isNotNull(); + // Verify that the returned list is immutable + assertThatThrownBy(() -> result.resourceTemplates() + .add(new McpSchema.ResourceTemplate("test://template", "test", "test", null, null, null))) + .isInstanceOf(UnsupportedOperationException.class); + }) + .verifyComplete(); + }); + } + // @Test void testResourceSubscription() { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -622,7 +679,7 @@ void testInitializeWithElicitationCapability() { @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) + .experimental(Map.of("feature", Map.of("featureFlag", true))) .roots(true) .sampling() .build(); @@ -642,7 +699,6 @@ void testInitializeWithAllCapabilities() { assertThat(result.capabilities()).isNotNull(); }).verifyComplete()); } - // --------------------------------------- // Logging Tests // --------------------------------------- @@ -722,7 +778,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback @@ -734,4 +790,39 @@ void testSampling() { }); } + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + List receivedNotifications = new CopyOnWriteArrayList<>(); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + receivedNotifications.add(notification); + sink.tryEmitNext(notification); + return Mono.empty(); + }), client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + StepVerifier.create(client.callTool(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + }).verifyComplete(); + + // Use StepVerifier to verify the progress notifications via the sink + StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3)); + + assertThat(receivedNotifications).hasSize(2); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 7736c233c..26d60568a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -13,14 +13,15 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -63,12 +64,6 @@ public abstract class AbstractMcpSyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -111,17 +106,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { @@ -551,11 +535,13 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); withClient(createMcpTransport(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), client -> { assertThatCode(() -> { @@ -638,7 +624,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback @@ -648,4 +634,48 @@ void testSampling() { }); } + // --------------------------------------- + // Progress Notification Tests + // --------------------------------------- + + @Test + void testProgressConsumer() { + AtomicInteger progressNotificationCount = new AtomicInteger(0); + List receivedNotifications = new CopyOnWriteArrayList<>(); + CountDownLatch latch = new CountDownLatch(2); + + withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { + System.out.println("Received progress notification: " + notification); + receivedNotifications.add(notification); + progressNotificationCount.incrementAndGet(); + latch.countDown(); + }), client -> { + client.initialize(); + + // Call a tool that sends progress notifications + CallToolRequest request = CallToolRequest.builder() + .name("longRunningOperation") + .arguments(Map.of("duration", 1, "steps", 2)) + .progressToken("test-token") + .build(); + + CallToolResult result = client.callTool(request); + + assertThat(result).isNotNull(); + + try { + // Wait for progress notifications to be processed + latch.await(3, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertThat(progressNotificationCount.get()).isEqualTo(2); + + assertThat(receivedNotifications).isNotEmpty(); + assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); + }); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 12827f469..d6677ec9a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -7,7 +7,6 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -21,9 +20,12 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -34,7 +36,6 @@ * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -43,7 +44,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -64,87 +65,210 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test + @Deprecated void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test + void testAddToolCall() { + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())))) + .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + List specs = List.of( + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } + + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build(), + McpServerFeatures.AsyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); + } + @Test void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -154,26 +278,27 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -187,7 +312,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -196,7 +321,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -208,13 +333,17 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); @@ -225,14 +354,13 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); }); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -241,41 +369,222 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + // --------------------------------------- // Prompts Tests // --------------------------------------- @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -284,31 +593,29 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -316,12 +623,10 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -330,13 +635,12 @@ void testRemovePromptWithoutCapability() { void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -348,15 +652,11 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); @@ -372,8 +672,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -392,8 +691,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -406,8 +704,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -419,9 +716,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index eefcdf9a3..0a59d0aae 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -6,7 +6,6 @@ import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -21,17 +20,17 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. + * {@link McpServerTransportProvider} implementations. * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -40,7 +39,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -68,114 +67,225 @@ void testConstructorWithInvalidArguments() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test + @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddToolCall() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(newTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test + @Deprecated void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()))) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateToolCall() { + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + @Test + void testDuplicateToolCallDuringBuilding() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); + } + + @Test + void testDuplicateToolsInBatchListRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) .build(); + List specs = List.of( + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ); + + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(specs) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-list-tool' is already registered."); + } - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + @Test + void testDuplicateToolsInBatchVarargsRegistration() { + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(), + McpServerFeatures.SyncToolSpecification.builder() + .tool(duplicateTool) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build() // Duplicate! + ) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); } @Test void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -184,78 +294,257 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Resource must not be null"); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") .build(); - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -264,79 +553,73 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -349,8 +632,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -358,9 +640,8 @@ void testRootsChangeHandlers() { } })) .build(); - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test with multiple consumers @@ -368,8 +649,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -377,26 +657,25 @@ void testRootsChangeHandlers() { .build(); assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java index 0085f31ed..dbbf1a537 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.server; import java.io.IOException; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..723965519 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonMapper.getDefault(); + +} \ No newline at end of file diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/pom.xml b/mcp/pom.xml index 773432827..0e0ed1288 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT mcp jar @@ -20,189 +20,20 @@ git@github.com/modelcontextprotocol/java-sdk.git - - - - biz.aQute.bnd - bnd-maven-plugin - ${bnd-maven-plugin.version} - - - bnd-process - - bnd-process - - - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - ${project.build.outputDirectory}/META-INF/MANIFEST.MF - - - - - - - org.slf4j - slf4j-api - ${slf4j-api.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - io.projectreactor - reactor-core - - - - org.springframework - spring-webmvc - ${springframework.version} - test - - - - - io.projectreactor.netty - reactor-netty-http - test - - - - - org.springframework - spring-context - ${springframework.version} - test - - - - org.springframework - spring-test - ${springframework.version} - test - - - - org.assertj - assertj-core - ${assert4j.version} - test - - - org.junit.jupiter - junit-jupiter-api - ${junit.version} - test - - - org.junit.jupiter - junit-jupiter-params - ${junit.version} - test - - - org.mockito - mockito-core - ${mockito.version} - test - - - - - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - - - io.projectreactor - reactor-test - test - - - org.testcontainers - junit-jupiter - ${testcontainers.version} - test - - - - org.awaitility - awaitility - ${awaitility.version} - test + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 0.18.0-SNAPSHOT - ch.qos.logback - logback-classic - ${logback.version} - test + io.modelcontextprotocol.sdk + mcp-core + 0.18.0-SNAPSHOT - - - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} - test - - - - - - jakarta.servlet - jakarta.servlet-api - ${jakarta.servlet.version} - provided - - - - - org.apache.tomcat.embed - tomcat-embed-core - ${tomcat.version} - test - - - org.apache.tomcat.embed - tomcat-embed-websocket - ${tomcat.version} - test - - - \ No newline at end of file + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java deleted file mode 100644 index abfafa551..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ /dev/null @@ -1,211 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.regex.Pattern; - -/** - * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive - * stream processing. This client establishes a connection to an SSE endpoint and - * processes the incoming event stream, parsing SSE-formatted messages into structured - * events. - * - *

- * The client supports standard SSE event fields including: - *

    - *
  • event - The event type (defaults to "message" if not specified)
  • - *
  • id - The event ID
  • - *
  • data - The event payload data
  • - *
- * - *

- * Events are delivered to a provided {@link SseEventHandler} which can process events and - * handle any errors that occur during the connection. - * - * @author Christian Tzolov - * @see SseEventHandler - * @see SseEvent - */ -public class FlowSseClient { - - private final HttpClient httpClient; - - private final HttpRequest.Builder requestBuilder; - - /** - * Pattern to extract the data content from SSE data field lines. Matches lines - * starting with "data:" and captures the remaining content. - */ - private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event ID from SSE id field lines. Matches lines starting - * with "id:" and captures the ID value. - */ - private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event type from SSE event field lines. Matches lines - * starting with "event:" and captures the event type. - */ - private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); - - /** - * Record class representing a Server-Sent Event with its standard fields. - * - * @param id the event ID (may be null) - * @param type the event type (defaults to "message" if not specified in the stream) - * @param data the event payload data - */ - public static record SseEvent(String id, String type, String data) { - } - - /** - * Interface for handling SSE events and errors. Implementations can process received - * events and handle any errors that occur during the SSE connection. - */ - public interface SseEventHandler { - - /** - * Called when an SSE event is received. - * @param event the received SSE event containing id, type, and data - */ - void onEvent(SseEvent event); - - /** - * Called when an error occurs during the SSE connection. - * @param error the error that occurred - */ - void onError(Throwable error); - - } - - /** - * Creates a new FlowSseClient with the specified HTTP client. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - */ - public FlowSseClient(HttpClient httpClient) { - this(httpClient, HttpRequest.newBuilder()); - } - - /** - * Creates a new FlowSseClient with the specified HTTP client and request builder. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests - */ - public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) { - this.httpClient = httpClient; - this.requestBuilder = requestBuilder; - } - - /** - * Subscribes to an SSE endpoint and processes the event stream. - * - *

- * This method establishes a connection to the specified URL and begins processing the - * SSE stream. Events are parsed and delivered to the provided event handler. The - * connection remains active until either an error occurs or the server closes the - * connection. - * @param url the SSE endpoint URL to connect to - * @param eventHandler the handler that will receive SSE events and error - * notifications - * @throws RuntimeException if the connection fails with a non-200 status code - */ - public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = this.requestBuilder.copy() - .uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); - - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } - } - } - subscription.request(1); - } - - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } - - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - } - } - }; - - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); - - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); - - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); - } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java deleted file mode 100644 index d951349d1..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ /dev/null @@ -1,470 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.client.transport; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -/** - * Server-Sent Events (SSE) implementation of the - * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification, using Java's HttpClient. - * - *

- * This transport implementation establishes a bidirectional communication channel between - * client and server using SSE for server-to-client messages and HTTP POST requests for - * client-to-server messages. The transport: - *

    - *
  • Establishes an SSE connection to receive server messages
  • - *
  • Handles endpoint discovery through SSE events
  • - *
  • Manages message serialization/deserialization using Jackson
  • - *
  • Provides graceful connection termination
  • - *
- * - *

- * The transport supports two types of SSE events: - *

    - *
  • 'endpoint' - Contains the URL for sending client messages
  • - *
  • 'message' - Contains JSON-RPC message payload
  • - *
- * - * @author Christian Tzolov - * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.McpClientTransport - */ -public class HttpClientSseClientTransport implements McpClientTransport { - - private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); - - /** SSE event type for JSON-RPC messages */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - /** SSE event type for endpoint discovery */ - private static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** Default SSE endpoint path */ - private static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** Base URI for the MCP server */ - private final URI baseUri; - - /** SSE endpoint path */ - private final String sseEndpoint; - - /** SSE client for handling server-sent events. Uses the /sse endpoint */ - private final FlowSseClient sseClient; - - /** - * HTTP client for sending messages to the server. Uses HTTP POST over the message - * endpoint - */ - private final HttpClient httpClient; - - /** HTTP request builder for building requests to send messages to the server */ - private final HttpRequest.Builder requestBuilder; - - /** JSON object mapper for message serialization/deserialization */ - protected ObjectMapper objectMapper; - - /** Flag indicating if the transport is in closing state */ - private volatile boolean isClosing = false; - - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); - - /** Holds the discovered message endpoint URL */ - private final AtomicReference messageEndpoint = new AtomicReference<>(); - - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); - - /** - * Creates a new transport instance with default HTTP client and object mapper. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(String baseUri) { - this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { - this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, - ObjectMapper objectMapper) { - this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param clientBuilder the HTTP client builder to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, - String baseUri, String sseEndpoint, ObjectMapper objectMapper) { - this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param httpClient the HTTP client to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - */ - HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.hasText(baseUri, "baseUri must not be empty"); - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(httpClient, "httpClient must not be null"); - Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.baseUri = URI.create(baseUri); - this.sseEndpoint = sseEndpoint; - this.objectMapper = objectMapper; - this.httpClient = httpClient; - this.requestBuilder = requestBuilder; - - this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); - } - - /** - * Creates a new builder for {@link HttpClientSseClientTransport}. - * @param baseUri the base URI of the MCP server - * @return a new builder instance - */ - public static Builder builder(String baseUri) { - return new Builder().baseUri(baseUri); - } - - /** - * Builder for {@link HttpClientSseClientTransport}. - */ - public static class Builder { - - private String baseUri; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private HttpClient.Builder clientBuilder = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_1_1) - .connectTimeout(Duration.ofSeconds(10)); - - private ObjectMapper objectMapper = new ObjectMapper(); - - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() - .header("Content-Type", "application/json"); - - /** - * Creates a new builder instance. - */ - Builder() { - // Default constructor - } - - /** - * Creates a new builder with the specified base URI. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. - * This constructor is deprecated and will be removed or made {@code protected} or - * {@code private} in a future release. - */ - @Deprecated(forRemoval = true) - public Builder(String baseUri) { - Assert.hasText(baseUri, "baseUri must not be empty"); - this.baseUri = baseUri; - } - - /** - * Sets the base URI. - * @param baseUri the base URI - * @return this builder - */ - Builder baseUri(String baseUri) { - Assert.hasText(baseUri, "baseUri must not be empty"); - this.baseUri = baseUri; - return this; - } - - /** - * Sets the SSE endpoint path. - * @param sseEndpoint the SSE endpoint path - * @return this builder - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the HTTP client builder. - * @param clientBuilder the HTTP client builder - * @return this builder - */ - public Builder clientBuilder(HttpClient.Builder clientBuilder) { - Assert.notNull(clientBuilder, "clientBuilder must not be null"); - this.clientBuilder = clientBuilder; - return this; - } - - /** - * Customizes the HTTP client builder. - * @param clientCustomizer the consumer to customize the HTTP client builder - * @return this builder - */ - public Builder customizeClient(final Consumer clientCustomizer) { - Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); - clientCustomizer.accept(clientBuilder); - return this; - } - - /** - * Sets the HTTP request builder. - * @param requestBuilder the HTTP request builder - * @return this builder - */ - public Builder requestBuilder(HttpRequest.Builder requestBuilder) { - Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.requestBuilder = requestBuilder; - return this; - } - - /** - * Customizes the HTTP client builder. - * @param requestCustomizer the consumer to customize the HTTP request builder - * @return this builder - */ - public Builder customizeRequest(final Consumer requestCustomizer) { - Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); - requestCustomizer.accept(requestBuilder); - return this; - } - - /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper - * @return this builder - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds a new {@link HttpClientSseClientTransport} instance. - * @return a new transport instance - */ - public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); - } - - } - - /** - * Establishes the SSE connection with the server and sets up message handling. - * - *

- * This method: - *

    - *
  • Initiates the SSE connection
  • - *
  • Handles endpoint discovery events
  • - *
  • Processes incoming JSON-RPC messages
  • - *
- * @param handler the function to process received JSON-RPC messages - * @return a Mono that completes when the connection is established - */ - @Override - public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); - - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { - @Override - public void onEvent(SseEvent event) { - if (isClosing) { - return; - } - - try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); - } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); - } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); - } - } - catch (IOException e) { - logger.error("Error processing SSE event", e); - future.completeExceptionally(e); - } - } - - @Override - public void onError(Throwable error) { - if (!isClosing) { - logger.error("SSE connection error", error); - future.completeExceptionally(error); - } - } - }); - - return Mono.fromFuture(future); - } - - /** - * Sends a JSON-RPC message to the server. - * - *

- * This method waits for the message endpoint to be discovered before sending the - * message. The message is serialized to JSON and sent as an HTTP POST request. - * @param message the JSON-RPC message to send - * @return a Mono that completes when the message is sent - * @throws McpError if the message endpoint is not available or the wait times out - */ - @Override - public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { - return Mono.empty(); - } - - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } - - try { - String jsonText = this.objectMapper.writeValueAsString(message); - URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } - } - - /** - * Gracefully closes the transport connection. - * - *

- * Sets the closing flag and cancels any pending connection future. This prevents new - * messages from being sent and allows ongoing operations to complete. - * @return a Mono that completes when the closing process is initiated - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); - } - }); - } - - /** - * Unmarshal data to the specified type using the configured object mapper. - * @param data the data to unmarshal - * @param typeRef the type reference for the target type - * @param the target type - * @return the unmarshalled object - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java deleted file mode 100644 index 02ad955b9..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ /dev/null @@ -1,750 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.BiFunction; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; -import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; -import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * The Model Context Protocol (MCP) server implementation that provides asynchronous - * communication using Project Reactor's Mono and Flux types. - * - *

- * This server implements the MCP specification, enabling AI models to expose tools, - * resources, and prompts through a standardized interface. Key features include: - *

    - *
  • Asynchronous communication using reactive programming patterns - *
  • Dynamic tool registration and management - *
  • Resource handling with URI-based addressing - *
  • Prompt template management - *
  • Real-time client notifications for state changes - *
  • Structured logging with configurable severity levels - *
  • Support for client-side AI model sampling - *
- * - *

- * The server follows a lifecycle: - *

    - *
  1. Initialization - Accepts client connections and negotiates capabilities - *
  2. Normal Operation - Handles client requests and sends notifications - *
  3. Graceful Shutdown - Ensures clean connection termination - *
- * - *

- * This implementation uses Project Reactor for non-blocking operations, making it - * suitable for high-throughput scenarios and reactive applications. All operations return - * Mono or Flux types that can be composed into reactive pipelines. - * - *

- * The server supports runtime modification of its capabilities through methods like - * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying - * connected clients of changes when configured to do so. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @author Jihoon Kim - * @see McpServer - * @see McpSchema - * @see McpClientSession - */ -public class McpAsyncServer { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - - private final McpServerTransportProvider mcpTransportProvider; - - private final ObjectMapper objectMapper; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private final String instructions; - - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - // FIXME: this field is deprecated and should be remvoed together with the - // broadcasting loggingNotification. - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); - - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - /** - * Create a new McpAsyncServer with the given transport provider and capabilities. - * @param mcpTransportProvider The transport layer implementation for MCP - * communication. - * @param features The MCP server supported features. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.instructions = features.instructions(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - this.completions.putAll(features.completions()); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - // Add completion API handlers if the completion capability is enabled - if (this.serverCapabilities.completions() != null) { - requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports - // if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST - // respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, this.instructions)); - }); - } - - /** - * Get the server capabilities that define the supported features and functionality. - * @return The server capabilities - */ - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - /** - * Get the server implementation information. - * @return The server implementation details - */ - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - - /** - * Gracefully closes the server, allowing any in-progress operations to complete. - * @return A Mono that completes when the server has been closed - */ - public Mono closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); - } - - /** - * Close the server immediately. - */ - public void close() { - this.mcpTransportProvider.close(); - } - - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() - .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - /** - * Add a new tool specification at runtime. - * @param toolSpecification The tool specification to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); - } - if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolSpecification.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); - } - - this.tools.add(toolSpecification); - logger.debug("Added tool handler: {}", toolSpecification.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a tool handler at runtime. - * @param toolName The name of the tool handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyToolsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } - - // --------------------------------------- - // Resource Management - // --------------------------------------- - - /** - * Add a new resource handler at runtime. - * @param resourceSpecification The resource handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { - if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a resource handler at runtime. - * @param resourceUri The URI of the resource handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - /** - * Notifies clients that the resources have updated. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, - resourcesUpdatedNotification); - } - - private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), - resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; - } - - private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - - return specification.readHandler().apply(exchange, resourceRequest); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- - - /** - * Add a new prompt handler at runtime. - * @param promptSpecification The prompt handler to add - * @return Mono that completes when clients have been notified of the change - */ - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error( - new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a prompt handler at runtime. - * @param promptName The name of the prompt handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyPromptsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); - if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return specification.promptHandler().apply(exchange, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - /** - * This implementation would, incorrectly, broadcast the logging message to all - * connected clients, using a single minLoggingLevel for all of them. Similar to the - * sampling and roots, the logging level should be set per client session and use the - * ServerExchange to send the logging message to the right client. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - * @deprecated Use - * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} - * instead. - */ - @Deprecated - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - - private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { - - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); - - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); - - return Mono.just(Map.of()); - }); - }; - } - - private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); - - if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); - } - - if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); - } - - String type = request.ref().type(); - - String argumentName = request.argument().name(); - - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); - if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); - } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { - - return Mono.error(new McpError("Argument not found: " + argumentName)); - } - } - - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); - } - - } - - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - - if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); - } - - return specification.completionHandler().apply(exchange, request); - }; - } - - /** - * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} - * object. - *

- * This method manually extracts the `ref` and `argument` fields from the input map, - * determines the correct reference type (either prompt or resource), and constructs a - * fully-typed {@code CompleteRequest} instance. - * @param object the raw request parameters, expected to be a Map containing "ref" and - * "argument" entries. - * @return a {@link McpSchema.CompleteRequest} representing the structured completion - * request. - * @throws IllegalArgumentException if the "ref" type is not recognized. - */ - @SuppressWarnings("unchecked") - private McpSchema.CompleteRequest parseCompletionParams(Object object) { - Map params = (Map) object; - Map refMap = (Map) params.get("ref"); - Map argMap = (Map) params.get("argument"); - - String refType = (String) refMap.get("type"); - - McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); - default -> throw new IllegalArgumentException("Invalid ref type: " + refType); - }; - - String argName = (String) argMap.get("name"); - String argValue = (String) argMap.get("value"); - McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, - argValue); - - return new McpSchema.CompleteRequest(ref, argument); - } - - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java deleted file mode 100644 index d6ec2cc30..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ /dev/null @@ -1,1138 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.BiFunction; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; -import reactor.core.publisher.Mono; - -/** - * Factory class for creating Model Context Protocol (MCP) servers. MCP servers expose - * tools, resources, and prompts to AI models through a standardized interface. - * - *

- * This class serves as the main entry point for implementing the server-side of the MCP - * specification. The server's responsibilities include: - *

    - *
  • Exposing tools that models can invoke to perform actions - *
  • Providing access to resources that give models context - *
  • Managing prompt templates for structured model interactions - *
  • Handling client connections and requests - *
  • Implementing capability negotiation - *
- * - *

- * Thread Safety: Both synchronous and asynchronous server implementations are - * thread-safe. The synchronous server processes requests sequentially, while the - * asynchronous server can handle concurrent requests safely through its reactive - * programming model. - * - *

- * Error Handling: The server implementations provide robust error handling through the - * McpError class. Errors are properly propagated to clients while maintaining the - * server's stability. Server implementations should use appropriate error codes and - * provide meaningful error messages to help diagnose issues. - * - *

- * The class provides factory methods to create either: - *

    - *
  • {@link McpAsyncServer} for non-blocking operations with reactive responses - *
  • {@link McpSyncServer} for blocking operations with direct responses - *
- * - *

- * Example of creating a basic synchronous server:

{@code
- * McpServer.sync(transportProvider)
- *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
- *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
- *     .build();
- * }
- * - * Example of creating a basic asynchronous server:
{@code
- * McpServer.async(transportProvider)
- *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
- *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
- *               .map(result -> new CallToolResult("Result: " + result)))
- *     .build();
- * }
- * - *

- * Example with comprehensive asynchronous configuration:

{@code
- * McpServer.async(transportProvider)
- *     .serverInfo("advanced-server", "2.0.0")
- *     .capabilities(new ServerCapabilities(...))
- *     // Register tools
- *     .tools(
- *         new McpServerFeatures.AsyncToolSpecification(calculatorTool,
- *             (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
- *                 .map(result -> new CallToolResult("Result: " + result))),
- *         new McpServerFeatures.AsyncToolSpecification(weatherTool,
- *             (exchange, args) -> Mono.fromSupplier(() -> getWeather(args))
- *                 .map(result -> new CallToolResult("Weather: " + result)))
- *     )
- *     // Register resources
- *     .resources(
- *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
- *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
- *                 .map(ReadResourceResult::new)),
- *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
- *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
- *                 .map(ReadResourceResult::new))
- *     )
- *     // Add resource templates
- *     .resourceTemplates(
- *         new ResourceTemplate("file://{path}", "Access files"),
- *         new ResourceTemplate("db://{table}", "Access database")
- *     )
- *     // Register prompts
- *     .prompts(
- *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
- *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
- *                 .map(GetPromptResult::new)),
- *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
- *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
- *                 .map(GetPromptResult::new))
- *     )
- *     .build();
- * }
- * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @author Jihoon Kim - * @see McpAsyncServer - * @see McpSyncServer - * @see McpServerTransportProvider - */ -public interface McpServer { - - /** - * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers block the current Thread's execution upon each request before - * giving the control back to the caller, making them simpler to implement but - * potentially less scalable for concurrent operations. - * @param transportProvider The transport layer implementation for MCP communication. - * @return A new instance of {@link SyncSpecification} for configuring the server. - */ - static SyncSpecification sync(McpServerTransportProvider transportProvider) { - return new SyncSpecification(transportProvider); - } - - /** - * Starts building an asynchronous MCP server that provides non-blocking operations. - * Asynchronous servers can handle multiple requests concurrently on a single Thread - * using a functional paradigm with non-blocking server transports, making them more - * scalable for high-concurrency scenarios but more complex to implement. - * @param transportProvider The transport layer implementation for MCP communication. - * @return A new instance of {@link AsyncSpecification} for configuring the server. - */ - static AsyncSpecification async(McpServerTransportProvider transportProvider) { - return new AsyncSpecification(transportProvider); - } - - /** - * Asynchronous server specification. - */ - class AsyncSpecification { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private final McpServerTransportProvider transportProvider; - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - private ObjectMapper objectMapper; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - private String instructions; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final Map completions = new HashMap<>(); - - private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - - private AsyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } - - /** - * Sets the URI template manager factory to use for creating URI templates. This - * allows for custom URI template parsing and variable extraction. - * @param uriTemplateManagerFactory The factory to use. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if uriTemplateManagerFactory is null - */ - public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { - Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - return this; - } - - /** - * Sets the duration to wait for server responses before timing out requests. This - * timeout applies to all requests made through the client, including tool calls, - * resource access, and prompt operations. - * @param requestTimeout The duration to wait before timing out requests. Must not - * be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if requestTimeout is null - */ - public AsyncSpecification requestTimeout(Duration requestTimeout) { - Assert.notNull(requestTimeout, "Request timeout must not be null"); - this.requestTimeout = requestTimeout; - return this; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public AsyncSpecification serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server instructions that will be shared with clients during connection - * initialization. These instructions provide guidance to the client on how to - * interact with this server. - * @param instructions The instructions text. Can be null or empty. - * @return This builder instance for method chaining - */ - public AsyncSpecification instructions(String instructions) { - this.instructions = instructions; - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
    - *
  • Tool execution - *
  • Resource access - *
  • Prompt handling - *
- * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { - Assert.notNull(serverCapabilities, "Server capabilities must not be null"); - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.AsyncToolSpecification} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
-		 *         .map(result -> new CallToolResult("Result: " + result))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * The function's first argument is an {@link McpAsyncServerExchange} upon which - * the server can interact with the connected client. The second argument is the - * map of arguments passed to the tool. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolSpecifications The list of tool specifications to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolSpecifications is null - * @see #tools(McpServerFeatures.AsyncToolSpecification...) - */ - public AsyncSpecification tools(List toolSpecifications) { - Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); - this.tools.addAll(toolSpecifications); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

- * Example usage:

{@code
-		 * .tools(
-		 *     new McpServerFeatures.AsyncToolSpecification(calculatorTool, calculatorHandler),
-		 *     new McpServerFeatures.AsyncToolSpecification(weatherTool, weatherHandler),
-		 *     new McpServerFeatures.AsyncToolSpecification(fileManagerTool, fileManagerHandler)
-		 * )
-		 * }
- * @param toolSpecifications The tool specifications to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolSpecifications is null - * @see #tools(List) - */ - public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { - Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); - for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceSpecifications Map of resource name to specification. Must not - * be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - * @see #resources(McpServerFeatures.AsyncResourceSpecification...) - */ - public AsyncSpecification resources( - Map resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); - this.resources.putAll(resourceSpecifications); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceSpecifications List of resource specifications. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - * @see #resources(McpServerFeatures.AsyncResourceSpecification...) - */ - public AsyncSpecification resources(List resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

- * Example usage:

{@code
-		 * .resources(
-		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
-		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
-		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
-		 * )
-		 * }
- * @param resourceSpecifications The resource specifications to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - */ - public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) - */ - public AsyncSpecification resourceTemplates(List resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(List) - */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

- * Example usage:

{@code
-		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
-		 *     new Prompt("analysis", "Code analysis template"),
-		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
-		 *         .map(GetPromptResult::new)
-		 * )));
-		 * }
- * @param prompts Map of prompt name to specification. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpecification prompts(Map prompts) { - Assert.notNull(prompts, "Prompts map must not be null"); - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) - */ - public AsyncSpecification prompts(List prompts) { - Assert.notNull(prompts, "Prompts list must not be null"); - for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

- * Example usage:

{@code
-		 * .prompts(
-		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
-		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
-		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
-		 * )
-		 * }
- * @param prompts The prompt specifications to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { - Assert.notNull(prompts, "Prompts list must not be null"); - for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple completions with their handlers using a List. This method is - * useful when completions need to be added in bulk from a collection. - * @param completions List of completion specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if completions is null - */ - public AsyncSpecification completions(List completions) { - Assert.notNull(completions, "Completions list must not be null"); - for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { - this.completions.put(completion.referenceKey(), completion); - } - return this; - } - - /** - * Registers multiple completions with their handlers using varargs. This method - * is useful when completions are defined inline and added directly. - * @param completions Array of completion specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if completions is null - */ - public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { - Assert.notNull(completions, "Completions list must not be null"); - for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { - this.completions.put(completion.referenceKey(), completion); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param handler The handler to register. Must not be null. The function's first - * argument is an {@link McpAsyncServerExchange} upon which the server can - * interact with the connected client. The second argument is the list of roots. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public AsyncSpecification rootsChangeHandler( - BiFunction, Mono> handler) { - Assert.notNull(handler, "Consumer must not be null"); - this.rootsChangeHandlers.add(handler); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param handlers The list of handlers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - * @see #rootsChangeHandler(BiFunction) - */ - public AsyncSpecification rootsChangeHandlers( - List, Mono>> handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - this.rootsChangeHandlers.addAll(handlers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param handlers The handlers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - * @see #rootsChangeHandlers(List) - */ - public AsyncSpecification rootsChangeHandlers( - @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - return this.rootsChangeHandlers(Arrays.asList(handlers)); - } - - /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. - * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null - */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's - * settings. - */ - public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory); - } - - } - - /** - * Synchronous server specification. - */ - class SyncSpecification { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - private final McpServerTransportProvider transportProvider; - - private ObjectMapper objectMapper; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - private String instructions; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final Map completions = new HashMap<>(); - - private final List>> rootsChangeHandlers = new ArrayList<>(); - - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - - private SyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } - - /** - * Sets the URI template manager factory to use for creating URI templates. This - * allows for custom URI template parsing and variable extraction. - * @param uriTemplateManagerFactory The factory to use. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if uriTemplateManagerFactory is null - */ - public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { - Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - return this; - } - - /** - * Sets the duration to wait for server responses before timing out requests. This - * timeout applies to all requests made through the client, including tool calls, - * resource access, and prompt operations. - * @param requestTimeout The duration to wait before timing out requests. Must not - * be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if requestTimeout is null - */ - public SyncSpecification requestTimeout(Duration requestTimeout) { - Assert.notNull(requestTimeout, "Request timeout must not be null"); - this.requestTimeout = requestTimeout; - return this; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public SyncSpecification serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server instructions that will be shared with clients during connection - * initialization. These instructions provide guidance to the client on how to - * interact with this server. - * @param instructions The instructions text. Can be null or empty. - * @return This builder instance for method chaining - */ - public SyncSpecification instructions(String instructions) { - this.instructions = instructions; - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
    - *
  • Tool execution - *
  • Resource access - *
  • Prompt handling - *
- * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { - Assert.notNull(serverCapabilities, "Server capabilities must not be null"); - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.SyncToolSpecification} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * The function's first argument is an {@link McpSyncServerExchange} upon which - * the server can interact with the connected client. The second argument is the - * list of arguments passed to the tool. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolSpecifications The list of tool specifications to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolSpecifications is null - * @see #tools(McpServerFeatures.SyncToolSpecification...) - */ - public SyncSpecification tools(List toolSpecifications) { - Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); - this.tools.addAll(toolSpecifications); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

- * Example usage:

{@code
-		 * .tools(
-		 *     new ToolSpecification(calculatorTool, calculatorHandler),
-		 *     new ToolSpecification(weatherTool, weatherHandler),
-		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
-		 * )
-		 * }
- * @param toolSpecifications The tool specifications to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolSpecifications is null - * @see #tools(List) - */ - public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { - Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); - for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceSpecifications Map of resource name to specification. Must not - * be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - * @see #resources(McpServerFeatures.SyncResourceSpecification...) - */ - public SyncSpecification resources( - Map resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); - this.resources.putAll(resourceSpecifications); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceSpecifications List of resource specifications. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - * @see #resources(McpServerFeatures.SyncResourceSpecification...) - */ - public SyncSpecification resources(List resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

- * Example usage:

{@code
-		 * .resources(
-		 *     new ResourceSpecification(fileResource, fileHandler),
-		 *     new ResourceSpecification(dbResource, dbHandler),
-		 *     new ResourceSpecification(apiResource, apiHandler)
-		 * )
-		 * }
- * @param resourceSpecifications The resource specifications to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceSpecifications is null - */ - public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { - Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) - */ - public SyncSpecification resourceTemplates(List resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceTemplates is null - * @see #resourceTemplates(List) - */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

- * Example usage:

{@code
-		 * Map prompts = new HashMap<>();
-		 * prompts.put("analysis", new PromptSpecification(
-		 *     new Prompt("analysis", "Code analysis template"),
-		 *     (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request))
-		 * ));
-		 * .prompts(prompts)
-		 * }
- * @param prompts Map of prompt name to specification. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpecification prompts(Map prompts) { - Assert.notNull(prompts, "Prompts map must not be null"); - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.SyncPromptSpecification...) - */ - public SyncSpecification prompts(List prompts) { - Assert.notNull(prompts, "Prompts list must not be null"); - for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

- * Example usage:

{@code
-		 * .prompts(
-		 *     new PromptSpecification(analysisPrompt, analysisHandler),
-		 *     new PromptSpecification(summaryPrompt, summaryHandler),
-		 *     new PromptSpecification(reviewPrompt, reviewHandler)
-		 * )
-		 * }
- * @param prompts The prompt specifications to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { - Assert.notNull(prompts, "Prompts list must not be null"); - for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple completions with their handlers using a List. This method is - * useful when completions need to be added in bulk from a collection. - * @param completions List of completion specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if completions is null - * @see #completions(McpServerFeatures.SyncCompletionSpecification...) - */ - public SyncSpecification completions(List completions) { - Assert.notNull(completions, "Completions list must not be null"); - for (McpServerFeatures.SyncCompletionSpecification completion : completions) { - this.completions.put(completion.referenceKey(), completion); - } - return this; - } - - /** - * Registers multiple completions with their handlers using varargs. This method - * is useful when completions are defined inline and added directly. - * @param completions Array of completion specifications. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if completions is null - */ - public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecification... completions) { - Assert.notNull(completions, "Completions list must not be null"); - for (McpServerFeatures.SyncCompletionSpecification completion : completions) { - this.completions.put(completion.referenceKey(), completion); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param handler The handler to register. Must not be null. The function's first - * argument is an {@link McpSyncServerExchange} upon which the server can interact - * with the connected client. The second argument is the list of roots. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public SyncSpecification rootsChangeHandler(BiConsumer> handler) { - Assert.notNull(handler, "Consumer must not be null"); - this.rootsChangeHandlers.add(handler); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param handlers The list of handlers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - * @see #rootsChangeHandler(BiConsumer) - */ - public SyncSpecification rootsChangeHandlers( - List>> handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - this.rootsChangeHandlers.addAll(handlers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param handlers The handlers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - * @see #rootsChangeHandlers(List) - */ - public SyncSpecification rootsChangeHandlers( - BiConsumer>... handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - return this.rootsChangeHandlers(List.of(handlers)); - } - - /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. - * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null - */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings. - */ - public McpSyncServer build() { - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory); - - return new McpSyncServer(asyncServer); - } - - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java deleted file mode 100644 index 13e43240b..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ /dev/null @@ -1,25 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; - -public class McpError extends RuntimeException { - - private JSONRPCError jsonRpcError; - - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.message()); - this.jsonRpcError = jsonRpcError; - } - - public McpError(Object error) { - super(error.toString()); - } - - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } - -} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java deleted file mode 100644 index 597130946..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ /dev/null @@ -1,1638 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.annotation.JsonTypeInfo.As; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Based on the JSON-RPC 2.0 - * specification and the Model - * Context Protocol Schema. - * - * @author Christian Tzolov - * @author Luca Chang - */ -public final class McpSchema { - - private static final Logger logger = LoggerFactory.getLogger(McpSchema.class); - - private McpSchema() { - } - - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; - - public static final String JSONRPC_VERSION = "2.0"; - - public static final String FIRST_PAGE = null; - - // --------------------------- - // Method Names - // --------------------------- - - // Lifecycle Methods - public static final String METHOD_INITIALIZE = "initialize"; - - public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized"; - - public static final String METHOD_PING = "ping"; - - // Tool Methods - public static final String METHOD_TOOLS_LIST = "tools/list"; - - public static final String METHOD_TOOLS_CALL = "tools/call"; - - public static final String METHOD_NOTIFICATION_TOOLS_LIST_CHANGED = "notifications/tools/list_changed"; - - // Resources Methods - public static final String METHOD_RESOURCES_LIST = "resources/list"; - - public static final String METHOD_RESOURCES_READ = "resources/read"; - - public static final String METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"; - - public static final String METHOD_NOTIFICATION_RESOURCES_UPDATED = "notifications/resources/updated"; - - public static final String METHOD_RESOURCES_TEMPLATES_LIST = "resources/templates/list"; - - public static final String METHOD_RESOURCES_SUBSCRIBE = "resources/subscribe"; - - public static final String METHOD_RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"; - - // Prompt Methods - public static final String METHOD_PROMPT_LIST = "prompts/list"; - - public static final String METHOD_PROMPT_GET = "prompts/get"; - - public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; - - public static final String METHOD_COMPLETION_COMPLETE = "completion/complete"; - - // Logging Methods - public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; - - public static final String METHOD_NOTIFICATION_MESSAGE = "notifications/message"; - - // Roots Methods - public static final String METHOD_ROOTS_LIST = "roots/list"; - - public static final String METHOD_NOTIFICATION_ROOTS_LIST_CHANGED = "notifications/roots/list_changed"; - - // Sampling Methods - public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; - - // Elicitation Methods - public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - // --------------------------- - // JSON-RPC Error Codes - // --------------------------- - /** - * Standard error codes used in MCP JSON-RPC responses. - */ - public static final class ErrorCodes { - - /** - * Invalid JSON was received by the server. - */ - public static final int PARSE_ERROR = -32700; - - /** - * The JSON sent is not a valid Request object. - */ - public static final int INVALID_REQUEST = -32600; - - /** - * The method does not exist / is not available. - */ - public static final int METHOD_NOT_FOUND = -32601; - - /** - * Invalid method parameter(s). - */ - public static final int INVALID_PARAMS = -32602; - - /** - * Internal JSON-RPC error. - */ - public static final int INTERNAL_ERROR = -32603; - - } - - public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, - CompleteRequest, GetPromptRequest { - - } - - private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { - }; - - /** - * Deserializes a JSON string into a JSONRPCMessage object. - * @param objectMapper The ObjectMapper instance to use for deserialization - * @param jsonText The JSON string to deserialize - * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, - * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. - * @throws IOException If there's an error during deserialization - * @throws IllegalArgumentException If the JSON structure doesn't match any known - * message type - */ - public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) - throws IOException { - - logger.debug("Received JSON message: {}", jsonText); - - var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); - - // Determine message type based on specific JSON structure - if (map.containsKey("method") && map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCRequest.class); - } - else if (map.containsKey("method") && !map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCNotification.class); - } - else if (map.containsKey("result") || map.containsKey("error")) { - return objectMapper.convertValue(map, JSONRPCResponse.class); - } - - throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); - } - - // --------------------------- - // JSON-RPC Message Types - // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { - - String jsonrpc(); - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support - // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) - public record JSONRPCRequest( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("method") String method, - @JsonProperty("id") Object id, - @JsonProperty("params") Object params) implements JSONRPCMessage { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support - // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) - public record JSONRPCNotification( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("method") String method, - @JsonProperty("params") Object params) implements JSONRPCMessage { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support - // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) - public record JSONRPCResponse( // @formatter:off - @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("id") Object id, - @JsonProperty("result") Object result, - @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JSONRPCError( - @JsonProperty("code") int code, - @JsonProperty("message") String message, - @JsonProperty("data") Object data) { - } - }// @formatter:on - - // --------------------------- - // Initialization - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record InitializeRequest( // @formatter:off - @JsonProperty("protocolVersion") String protocolVersion, - @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record InitializeResult( // @formatter:off - @JsonProperty("protocolVersion") String protocolVersion, - @JsonProperty("capabilities") ServerCapabilities capabilities, - @JsonProperty("serverInfo") Implementation serverInfo, - @JsonProperty("instructions") String instructions) { - } // @formatter:on - - /** - * Clients can implement additional features to enrich connected MCP servers with - * additional capabilities. These capabilities can be used to extend the functionality - * of the server, or to provide additional information to the server about the - * client's capabilities. - * - * @param experimental WIP - * @param roots define the boundaries of where servers can operate within the - * filesystem, allowing them to understand which directories and files they have - * access to. - * @param sampling Provides a standardized way for servers to request LLM sampling - * (“completions” or “generations”) from language models via clients. - * @param elicitation Provides a standardized way for servers to request additional - * information from users through the client during interactions. - * - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ClientCapabilities( // @formatter:off - @JsonProperty("experimental") Map experimental, - @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling, - @JsonProperty("elicitation") Elicitation elicitation) { - - /** - * Roots define the boundaries of where servers can operate within the filesystem, - * allowing them to understand which directories and files they have access to. - * Servers can request the list of roots from supporting clients and - * receive notifications when that list changes. - * - * @param listChanged Whether the client would send notification about roots - * has changed since the last time the server checked. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record RootCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - /** - * Provides a standardized way for servers to request LLM - * sampling ("completions" or "generations") from language - * models via clients. This flow allows clients to maintain - * control over model access, selection, and permissions - * while enabling servers to leverage AI capabilities—with - * no server API keys necessary. Servers can request text or - * image-based interactions and optionally include context - * from MCP servers in their prompts. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record Sampling() { - } - - /** - * Provides a standardized way for servers to request additional - * information from users through the client during interactions. - * This flow allows clients to maintain control over user - * interactions and data sharing while enabling servers to gather - * necessary information dynamically. Servers can request structured - * data from users with optional JSON schemas to validate responses. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record Elicitation() { - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Map experimental; - private RootCapabilities roots; - private Sampling sampling; - private Elicitation elicitation; - - public Builder experimental(Map experimental) { - this.experimental = experimental; - return this; - } - - public Builder roots(Boolean listChanged) { - this.roots = new RootCapabilities(listChanged); - return this; - } - - public Builder sampling() { - this.sampling = new Sampling(); - return this; - } - - public Builder elicitation() { - this.elicitation = new Elicitation(); - return this; - } - - public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling, elicitation); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ServerCapabilities( // @formatter:off - @JsonProperty("completions") CompletionCapabilities completions, - @JsonProperty("experimental") Map experimental, - @JsonProperty("logging") LoggingCapabilities logging, - @JsonProperty("prompts") PromptCapabilities prompts, - @JsonProperty("resources") ResourceCapabilities resources, - @JsonProperty("tools") ToolCapabilities tools) { - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record CompletionCapabilities() { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record LoggingCapabilities() { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record PromptCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record ResourceCapabilities( - @JsonProperty("subscribe") Boolean subscribe, - @JsonProperty("listChanged") Boolean listChanged) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record ToolCapabilities( - @JsonProperty("listChanged") Boolean listChanged) { - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private CompletionCapabilities completions; - private Map experimental; - private LoggingCapabilities logging = new LoggingCapabilities(); - private PromptCapabilities prompts; - private ResourceCapabilities resources; - private ToolCapabilities tools; - - public Builder completions() { - this.completions = new CompletionCapabilities(); - return this; - } - - public Builder experimental(Map experimental) { - this.experimental = experimental; - return this; - } - - public Builder logging() { - this.logging = new LoggingCapabilities(); - return this; - } - - public Builder prompts(Boolean listChanged) { - this.prompts = new PromptCapabilities(listChanged); - return this; - } - - public Builder resources(Boolean subscribe, Boolean listChanged) { - this.resources = new ResourceCapabilities(subscribe, listChanged); - return this; - } - - public Builder tools(Boolean listChanged) { - this.tools = new ToolCapabilities(listChanged); - return this; - } - - public ServerCapabilities build() { - return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); - } - } - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Implementation(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("version") String version) { - } // @formatter:on - - // Existing Enums and Base Types (from previous implementation) - public enum Role {// @formatter:off - - @JsonProperty("user") USER, - @JsonProperty("assistant") ASSISTANT - }// @formatter:on - - // --------------------------- - // Resource Interfaces - // --------------------------- - /** - * Base for objects that include optional annotations for the client. The client can - * use annotations to inform how objects are used or displayed - */ - public interface Annotated { - - Annotations annotations(); - - } - - /** - * Optional annotations for the client. The client can use annotations to inform how - * objects are used or displayed. - * - * @param audience Describes who the intended customer of this object or data is. It - * can include multiple entries to indicate content useful for multiple audiences - * (e.g., `["user", "assistant"]`). - * @param priority Describes how important this data is for operating the server. A - * value of 1 means "most important," and indicates that the data is effectively - * required, while 0 means "least important," and indicates that the data is entirely - * optional. It is a number between 0 and 1. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Annotations( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority) { - } // @formatter:on - - /** - * A known resource that the server is capable of reading. - * - * @param uri the URI of the resource. - * @param name A human-readable name for this resource. This can be used by clients to - * populate UI elements. - * @param description A description of what this resource represents. This can be used - * by clients to improve the LLM's understanding of available resources. It can be - * thought of like a "hint" to the model. - * @param mimeType The MIME type of this resource, if known. - * @param size The size of the raw resource content, in bytes (i.e., before base64 - * encoding or any tokenization), if known. This can be used by Hosts to display file - * sizes and estimate context window usage. - * @param annotations Optional annotations for the client. The client can use - * annotations to inform how objects are used or displayed. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Resource( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("size") Long size, - @JsonProperty("annotations") Annotations annotations) implements Annotated { - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Resource#builder()} instead. - */ - @Deprecated - public Resource(String uri, String name, String description, String mimeType, Annotations annotations) { - this(uri, name, description, mimeType, null, annotations); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private String uri; - private String name; - private String description; - private String mimeType; - private Long size; - private Annotations annotations; - - public Builder uri(String uri) { - this.uri = uri; - return this; - } - - public Builder name(String name) { - this.name = name; - return this; - } - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder mimeType(String mimeType) { - this.mimeType = mimeType; - return this; - } - - public Builder size(Long size) { - this.size = size; - return this; - } - - public Builder annotations(Annotations annotations) { - this.annotations = annotations; - return this; - } - - public Resource build() { - Assert.hasText(uri, "uri must not be empty"); - Assert.hasText(name, "name must not be empty"); - - return new Resource(uri, name, description, mimeType, size, annotations); - } - } - } // @formatter:on - - /** - * Resource templates allow servers to expose parameterized resources using URI - * templates. - * - * @param uriTemplate A URI template that can be used to generate URIs for this - * resource. - * @param name A human-readable name for this resource. This can be used by clients to - * populate UI elements. - * @param description A description of what this resource represents. This can be used - * by clients to improve the LLM's understanding of available resources. It can be - * thought of like a "hint" to the model. - * @param mimeType The MIME type of this resource, if known. - * @param annotations Optional annotations for the client. The client can use - * annotations to inform how objects are used or displayed. - * @see RFC 6570 - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ResourceTemplate( // @formatter:off - @JsonProperty("uriTemplate") String uriTemplate, - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("annotations") Annotations annotations) implements Annotated { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListResourcesResult( // @formatter:off - @JsonProperty("resources") List resources, - @JsonProperty("nextCursor") String nextCursor) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListResourceTemplatesResult( // @formatter:off - @JsonProperty("resourceTemplates") List resourceTemplates, - @JsonProperty("nextCursor") String nextCursor) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ReadResourceRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ReadResourceResult( // @formatter:off - @JsonProperty("contents") List contents){ - } // @formatter:on - - /** - * Sent from the client to request resources/updated notifications from the server - * whenever a particular resource changes. - * - * @param uri the URI of the resource to subscribe to. The URI can use any protocol; - * it is up to the server how to interpret it. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record SubscribeRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record UnsubscribeRequest( // @formatter:off - @JsonProperty("uri") String uri){ - } // @formatter:on - - /** - * The contents of a specific resource or sub-resource. - */ - @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, include = As.PROPERTY) - @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class, name = "text"), - @JsonSubTypes.Type(value = BlobResourceContents.class, name = "blob") }) - public sealed interface ResourceContents permits TextResourceContents, BlobResourceContents { - - /** - * The URI of this resource. - * @return the URI of this resource. - */ - String uri(); - - /** - * The MIME type of this resource. - * @return the MIME type of this resource. - */ - String mimeType(); - - } - - /** - * Text contents of a resource. - * - * @param uri the URI of this resource. - * @param mimeType the MIME type of this resource. - * @param text the text of the resource. This must only be set if the resource can - * actually be represented as text (not binary data). - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record TextResourceContents( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("text") String text) implements ResourceContents { - } // @formatter:on - - /** - * Binary contents of a resource. - * - * @param uri the URI of this resource. - * @param mimeType the MIME type of this resource. - * @param blob a base64-encoded string representing the binary data of the resource. - * This must only be set if the resource can actually be represented as binary data - * (not text). - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record BlobResourceContents( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("mimeType") String mimeType, - @JsonProperty("blob") String blob) implements ResourceContents { - } // @formatter:on - - // --------------------------- - // Prompt Interfaces - // --------------------------- - /** - * A prompt or prompt template that the server offers. - * - * @param name The name of the prompt or prompt template. - * @param description An optional description of what this prompt provides. - * @param arguments A list of arguments to use for templating the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Prompt( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("arguments") List arguments) { - } // @formatter:on - - /** - * Describes an argument that a prompt can accept. - * - * @param name The name of the argument. - * @param description A human-readable description of the argument. - * @param required Whether this argument must be provided. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PromptArgument( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("required") Boolean required) { - }// @formatter:on - - /** - * Describes a message returned as part of a prompt. - * - * This is similar to `SamplingMessage`, but also supports the embedding of resources - * from the MCP server. - * - * @param role The sender or recipient of messages and data in a conversation. - * @param content The content of the message of type {@link Content}. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PromptMessage( // @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content) { - } // @formatter:on - - /** - * The server's response to a prompts/list request from the client. - * - * @param prompts A list of prompts that the server provides. - * @param nextCursor An optional cursor for pagination. If present, indicates there - * are more prompts available. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListPromptsResult( // @formatter:off - @JsonProperty("prompts") List prompts, - @JsonProperty("nextCursor") String nextCursor) { - }// @formatter:on - - /** - * Used by the client to get a prompt provided by the server. - * - * @param name The name of the prompt or prompt template. - * @param arguments Arguments to use for templating the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record GetPromptRequest(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("arguments") Map arguments) implements Request { - }// @formatter:off - - /** - * The server's response to a prompts/get request from the client. - * - * @param description An optional description for the prompt. - * @param messages A list of messages to display as part of the prompt. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record GetPromptResult( // @formatter:off - @JsonProperty("description") String description, - @JsonProperty("messages") List messages) { - } // @formatter:on - - // --------------------------- - // Tool Interfaces - // --------------------------- - /** - * The server's response to a tools/list request from the client. - * - * @param tools A list of tools that the server provides. - * @param nextCursor An optional cursor for pagination. If present, indicates there - * are more tools available. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListToolsResult( // @formatter:off - @JsonProperty("tools") List tools, - @JsonProperty("nextCursor") String nextCursor) { - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record JsonSchema( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("properties") Map properties, - @JsonProperty("required") List required, - @JsonProperty("additionalProperties") Boolean additionalProperties, - @JsonProperty("$defs") Map defs, - @JsonProperty("definitions") Map definitions) { - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ToolAnnotations( // @formatter:off - @JsonProperty("title") String title, - @JsonProperty("readOnlyHint") Boolean readOnlyHint, - @JsonProperty("destructiveHint") Boolean destructiveHint, - @JsonProperty("idempotentHint") Boolean idempotentHint, - @JsonProperty("openWorldHint") Boolean openWorldHint, - @JsonProperty("returnDirect") Boolean returnDirect) { - } // @formatter:on - - /** - * Represents a tool that the server provides. Tools enable servers to expose - * executable functionality to the system. Through these tools, you can interact with - * external systems, perform computations, and take actions in the real world. - * - * @param name A unique identifier for the tool. This name is used when calling the - * tool. - * @param description A human-readable description of what the tool does. This can be - * used by clients to improve the LLM's understanding of available tools. - * @param inputSchema A JSON Schema object that describes the expected structure of - * the arguments when calling this tool. This allows clients to validate tool - * @param annotations Additional properties describing a Tool to clients. arguments - * before sending them to the server. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("inputSchema") JsonSchema inputSchema, - @JsonProperty("annotations") ToolAnnotations annotations) { - - public Tool(String name, String description, String schema) { - this(name, description, parseSchema(schema), null); - } - - public Tool(String name, String description, String schema, ToolAnnotations annotations) { - this(name, description, parseSchema(schema), annotations); - } - - } // @formatter:on - - private static JsonSchema parseSchema(String schema) { - try { - return OBJECT_MAPPER.readValue(schema, JsonSchema.class); - } - catch (IOException e) { - throw new IllegalArgumentException("Invalid schema: " + schema, e); - } - } - - /** - * Used by the client to call a tool provided by the server. - * - * @param name The name of the tool to call. This must match a tool name from - * tools/list. - * @param arguments Arguments to pass to the tool. These must conform to the tool's - * input schema. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CallToolRequest(// @formatter:off - @JsonProperty("name") String name, - @JsonProperty("arguments") Map arguments) implements Request { - - public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); - } - - private static Map parseJsonArguments(String jsonArguments) { - try { - return OBJECT_MAPPER.readValue(jsonArguments, MAP_TYPE_REF); - } - catch (IOException e) { - throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); - } - } - }// @formatter:off - - /** - * The server's response to a tools/call request from the client. - * - * @param content A list of content items representing the tool's output. Each item can be text, an image, - * or an embedded resource. - * @param isError If true, indicates that the tool execution failed and the content contains error information. - * If false or absent, indicates successful execution. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CallToolResult( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("isError") Boolean isError) { - - /** - * Creates a new instance of {@link CallToolResult} with a string containing the - * tool result. - * - * @param content The content of the tool result. This will be mapped to a one-sized list - * with a {@link TextContent} element. - * @param isError If true, indicates that the tool execution failed and the content contains error information. - * If false or absent, indicates successful execution. - */ - public CallToolResult(String content, Boolean isError) { - this(List.of(new TextContent(content)), isError); - } - - /** - * Creates a builder for {@link CallToolResult}. - * @return a new builder instance - */ - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for {@link CallToolResult}. - */ - public static class Builder { - private List content = new ArrayList<>(); - private Boolean isError; - - /** - * Sets the content list for the tool result. - * @param content the content list - * @return this builder - */ - public Builder content(List content) { - Assert.notNull(content, "content must not be null"); - this.content = content; - return this; - } - - /** - * Sets the text content for the tool result. - * @param textContent the text content - * @return this builder - */ - public Builder textContent(List textContent) { - Assert.notNull(textContent, "textContent must not be null"); - textContent.stream() - .map(TextContent::new) - .forEach(this.content::add); - return this; - } - - /** - * Adds a content item to the tool result. - * @param contentItem the content item to add - * @return this builder - */ - public Builder addContent(Content contentItem) { - Assert.notNull(contentItem, "contentItem must not be null"); - if (this.content == null) { - this.content = new ArrayList<>(); - } - this.content.add(contentItem); - return this; - } - - /** - * Adds a text content item to the tool result. - * @param text the text content - * @return this builder - */ - public Builder addTextContent(String text) { - Assert.notNull(text, "text must not be null"); - return addContent(new TextContent(text)); - } - - /** - * Sets whether the tool execution resulted in an error. - * @param isError true if the tool execution failed, false otherwise - * @return this builder - */ - public Builder isError(Boolean isError) { - Assert.notNull(isError, "isError must not be null"); - this.isError = isError; - return this; - } - - /** - * Builds a new {@link CallToolResult} instance. - * @return a new CallToolResult instance - */ - public CallToolResult build() { - return new CallToolResult(content, isError); - } - } - - } // @formatter:on - - // --------------------------- - // Sampling Interfaces - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ModelPreferences(// @formatter:off - @JsonProperty("hints") List hints, - @JsonProperty("costPriority") Double costPriority, - @JsonProperty("speedPriority") Double speedPriority, - @JsonProperty("intelligencePriority") Double intelligencePriority) { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private List hints; - private Double costPriority; - private Double speedPriority; - private Double intelligencePriority; - - public Builder hints(List hints) { - this.hints = hints; - return this; - } - - public Builder addHint(String name) { - if (this.hints == null) { - this.hints = new ArrayList<>(); - } - this.hints.add(new ModelHint(name)); - return this; - } - - public Builder costPriority(Double costPriority) { - this.costPriority = costPriority; - return this; - } - - public Builder speedPriority(Double speedPriority) { - this.speedPriority = speedPriority; - return this; - } - - public Builder intelligencePriority(Double intelligencePriority) { - this.intelligencePriority = intelligencePriority; - return this; - } - - public ModelPreferences build() { - return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority); - } - } -} // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ModelHint(@JsonProperty("name") String name) { - public static ModelHint of(String name) { - return new ModelHint(name); - } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record SamplingMessage(// @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content) { - } // @formatter:on - - // Sampling and Message Creation - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CreateMessageRequest(// @formatter:off - @JsonProperty("messages") List messages, - @JsonProperty("modelPreferences") ModelPreferences modelPreferences, - @JsonProperty("systemPrompt") String systemPrompt, - @JsonProperty("includeContext") ContextInclusionStrategy includeContext, - @JsonProperty("temperature") Double temperature, - @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, - @JsonProperty("metadata") Map metadata) implements Request { - - public enum ContextInclusionStrategy { - @JsonProperty("none") NONE, - @JsonProperty("thisServer") THIS_SERVER, - @JsonProperty("allServers") ALL_SERVERS - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private List messages; - private ModelPreferences modelPreferences; - private String systemPrompt; - private ContextInclusionStrategy includeContext; - private Double temperature; - private int maxTokens; - private List stopSequences; - private Map metadata; - - public Builder messages(List messages) { - this.messages = messages; - return this; - } - - public Builder modelPreferences(ModelPreferences modelPreferences) { - this.modelPreferences = modelPreferences; - return this; - } - - public Builder systemPrompt(String systemPrompt) { - this.systemPrompt = systemPrompt; - return this; - } - - public Builder includeContext(ContextInclusionStrategy includeContext) { - this.includeContext = includeContext; - return this; - } - - public Builder temperature(Double temperature) { - this.temperature = temperature; - return this; - } - - public Builder maxTokens(int maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public Builder stopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder metadata(Map metadata) { - this.metadata = metadata; - return this; - } - - public CreateMessageRequest build() { - return new CreateMessageRequest(messages, modelPreferences, systemPrompt, - includeContext, temperature, maxTokens, stopSequences, metadata); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CreateMessageResult(// @formatter:off - @JsonProperty("role") Role role, - @JsonProperty("content") Content content, - @JsonProperty("model") String model, - @JsonProperty("stopReason") StopReason stopReason) { - - public enum StopReason { - @JsonProperty("endTurn") END_TURN("endTurn"), - @JsonProperty("stopSequence") STOP_SEQUENCE("stopSequence"), - @JsonProperty("maxTokens") MAX_TOKENS("maxTokens"), - @JsonProperty("unknown") UNKNOWN("unknown"); - - private final String value; - - StopReason(String value) { - this.value = value; - } - - @JsonCreator - private static StopReason of(String value) { - return Arrays.stream(StopReason.values()) - .filter(stopReason -> stopReason.value.equals(value)) - .findFirst() - .orElse(StopReason.UNKNOWN); - } - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Role role = Role.ASSISTANT; - private Content content; - private String model; - private StopReason stopReason = StopReason.END_TURN; - - public Builder role(Role role) { - this.role = role; - return this; - } - - public Builder content(Content content) { - this.content = content; - return this; - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder stopReason(StopReason stopReason) { - this.stopReason = stopReason; - return this; - } - - public Builder message(String message) { - this.content = new TextContent(message); - return this; - } - - public CreateMessageResult build() { - return new CreateMessageResult(role, content, model, stopReason); - } - } - }// @formatter:on - - // Elicitation - /** - * Used by the server to send an elicitation to the client. - * - * @param message The body of the elicitation message. - * @param requestedSchema The elicitation response schema that must be satisfied. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ElicitRequest(// @formatter:off - @JsonProperty("message") String message, - @JsonProperty("requestedSchema") Map requestedSchema) implements Request { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private String message; - private Map requestedSchema; - - public Builder message(String message) { - this.message = message; - return this; - } - - public Builder requestedSchema(Map requestedSchema) { - this.requestedSchema = requestedSchema; - return this; - } - - public ElicitRequest build() { - return new ElicitRequest(message, requestedSchema); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ElicitResult(// @formatter:off - @JsonProperty("action") Action action, - @JsonProperty("content") Map content) { - - public enum Action { - @JsonProperty("accept") ACCEPT, - @JsonProperty("decline") DECLINE, - @JsonProperty("cancel") CANCEL - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Action action; - private Map content; - - public Builder message(Action action) { - this.action = action; - return this; - } - - public Builder content(Map content) { - this.content = content; - return this; - } - - public ElicitResult build() { - return new ElicitResult(action, content); - } - } - }// @formatter:on - - // --------------------------- - // Pagination Interfaces - // --------------------------- - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PaginatedRequest(@JsonProperty("cursor") String cursor) { - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { - } - - // --------------------------- - // Progress and Logging - // --------------------------- - @JsonIgnoreProperties(ignoreUnknown = true) - public record ProgressNotification(// @formatter:off - @JsonProperty("progressToken") String progressToken, - @JsonProperty("progress") double progress, - @JsonProperty("total") Double total) { - }// @formatter:on - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to send - * resources update message to clients. - * - * @param uri The updated resource uri. - */ - @JsonIgnoreProperties(ignoreUnknown = true) - public record ResourcesUpdatedNotification(// @formatter:off - @JsonProperty("uri") String uri) { - }// @formatter:on - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to send - * structured log messages to clients. Clients can control logging verbosity by - * setting minimum log levels, with servers sending notifications containing severity - * levels, optional logger names, and arbitrary JSON-serializable data. - * - * @param level The severity levels. The minimum log level is set by the client. - * @param logger The logger that generated the message. - * @param data JSON-serializable logging data. - */ - @JsonIgnoreProperties(ignoreUnknown = true) - public record LoggingMessageNotification(// @formatter:off - @JsonProperty("level") LoggingLevel level, - @JsonProperty("logger") String logger, - @JsonProperty("data") String data) { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private LoggingLevel level = LoggingLevel.INFO; - private String logger = "server"; - private String data; - - public Builder level(LoggingLevel level) { - this.level = level; - return this; - } - - public Builder logger(String logger) { - this.logger = logger; - return this; - } - - public Builder data(String data) { - this.data = data; - return this; - } - - public LoggingMessageNotification build() { - return new LoggingMessageNotification(level, logger, data); - } - } - }// @formatter:on - - public enum LoggingLevel {// @formatter:off - @JsonProperty("debug") DEBUG(0), - @JsonProperty("info") INFO(1), - @JsonProperty("notice") NOTICE(2), - @JsonProperty("warning") WARNING(3), - @JsonProperty("error") ERROR(4), - @JsonProperty("critical") CRITICAL(5), - @JsonProperty("alert") ALERT(6), - @JsonProperty("emergency") EMERGENCY(7); - - private final int level; - - LoggingLevel(int level) { - this.level = level; - } - - public int level() { - return level; - } - - } // @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { - } - - // --------------------------- - // Autocomplete - // --------------------------- - public sealed interface CompleteReference permits PromptReference, ResourceReference { - - String type(); - - String identifier(); - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PromptReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("name") String name) implements McpSchema.CompleteReference { - - public PromptReference(String name) { - this("ref/prompt", name); - } - - @Override - public String identifier() { - return name(); - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ResourceReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { - - public ResourceReference(String uri) { - this("ref/resource", uri); - } - - @Override - public String identifier() { - return uri(); - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteRequest(// @formatter:off - @JsonProperty("ref") McpSchema.CompleteReference ref, - @JsonProperty("argument") CompleteArgument argument) implements Request { - - public record CompleteArgument( - @JsonProperty("name") String name, - @JsonProperty("value") String value) { - }// @formatter:on - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteResult(@JsonProperty("completion") CompleteCompletion completion) { // @formatter:off - - public record CompleteCompletion( - @JsonProperty("values") List values, - @JsonProperty("total") Integer total, - @JsonProperty("hasMore") Boolean hasMore) { - }// @formatter:on - } - - // --------------------------- - // Content Types - // --------------------------- - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") - @JsonSubTypes({ @JsonSubTypes.Type(value = TextContent.class, name = "text"), - @JsonSubTypes.Type(value = ImageContent.class, name = "image"), - @JsonSubTypes.Type(value = AudioContent.class, name = "audio"), - @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource") }) - public sealed interface Content permits TextContent, ImageContent, AudioContent, EmbeddedResource { - - default String type() { - if (this instanceof TextContent) { - return "text"; - } - else if (this instanceof ImageContent) { - return "image"; - } - else if (this instanceof AudioContent) { - return "audio"; - } - else if (this instanceof EmbeddedResource) { - return "resource"; - } - throw new IllegalArgumentException("Unknown content type: " + this); - } - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record TextContent( // @formatter:off - @JsonProperty("annotations") Annotations annotations, - @JsonProperty("text") String text) implements Annotated, Content { // @formatter:on - - public TextContent(String content) { - this(null, content); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#TextContent(Annotations, String)} instead. - */ - public TextContent(List audience, Double priority, String content) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, content); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#annotations()} instead. - */ - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#annotations()} instead. - */ - public Double priority() { - return annotations == null ? null : annotations.priority(); - } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ImageContent( // @formatter:off - @JsonProperty("annotations") Annotations annotations, - @JsonProperty("data") String data, - @JsonProperty("mimeType") String mimeType) implements Annotated, Content { // @formatter:on - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#ImageContent(Annotations, String, String)} instead. - */ - public ImageContent(List audience, Double priority, String data, String mimeType) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, data, mimeType); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#annotations()} instead. - */ - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#annotations()} instead. - */ - public Double priority() { - return annotations == null ? null : annotations.priority(); - } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record AudioContent( // @formatter:off - @JsonProperty("annotations") Annotations annotations, - @JsonProperty("data") String data, - @JsonProperty("mimeType") String mimeType) implements Annotated, Content { // @formatter:on - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record EmbeddedResource( // @formatter:off - @JsonProperty("annotations") Annotations annotations, - @JsonProperty("resource") ResourceContents resource) implements Annotated, Content { // @formatter:on - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#EmbeddedResource(Annotations, ResourceContents)} - * instead. - */ - public EmbeddedResource(List audience, Double priority, ResourceContents resource) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, resource); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#annotations()} instead. - */ - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#annotations()} instead. - */ - public Double priority() { - return annotations == null ? null : annotations.priority(); - } - } - - // --------------------------- - // Roots - // --------------------------- - /** - * Represents a root directory or file that the server can operate on. - * - * @param uri The URI identifying the root. This *must* start with file:// for now. - * This restriction may be relaxed in future versions of the protocol to allow other - * URI schemes. - * @param name An optional name for the root. This can be used to provide a - * human-readable identifier for the root, which may be useful for display purposes or - * for referencing the root in other parts of the application. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record Root( // @formatter:off - @JsonProperty("uri") String uri, - @JsonProperty("name") String name) { - } // @formatter:on - - /** - * The client's response to a roots/list request from the server. This result contains - * an array of Root objects, each representing a root directory or file that the - * server can operate on. - * - * @param roots An array of Root objects, each representing a root directory or file - * that the server can operate on. - */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record ListRootsResult( // @formatter:off - @JsonProperty("roots") List roots) { - } // @formatter:on - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java deleted file mode 100644 index 8646c1b4c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -/** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { - - String host = "http://localhost:3003"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return HttpClientSseClientTransport.builder(host).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - protected void onClose() { - container.stop(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java deleted file mode 100644 index 14ca82791..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java +++ /dev/null @@ -1,81 +0,0 @@ -package io.modelcontextprotocol.client; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; - -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.assertThatCode; - -class McpAsyncClientTests { - - public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", - "1.0.0"); - - public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() - .build(); - - public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( - McpSchema.LATEST_PROTOCOL_VERSION, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); - - private static final String CONTEXT_KEY = "context.key"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - @Test - void validateContextPassedToTransportConnect() { - McpClientTransport transport = new McpClientTransport() { - Function, Mono> handler; - - final AtomicReference contextValue = new AtomicReference<>(); - - @Override - public Mono connect( - Function, Mono> handler) { - return Mono.deferContextual(ctx -> { - this.handler = handler; - if (ctx.hasKey(CONTEXT_KEY)) { - this.contextValue.set(ctx.get(CONTEXT_KEY)); - } - return Mono.empty(); - }); - } - - @Override - public Mono closeGracefully() { - return Mono.empty(); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (!"hello".equals(this.contextValue.get())) { - return Mono.error(new RuntimeException("Context value not propagated via #connect method")); - } - // We're only interested in handling the init request to provide an init - // response - if (!(message instanceof McpSchema.JSONRPCRequest)) { - return Mono.empty(); - } - McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, - ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); - return handler.apply(Mono.just(initResponse)).then(); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return OBJECT_MAPPER.convertValue(data, typeRef); - } - }; - - assertThatCode(() -> { - McpAsyncClient client = McpClient.async(transport).build(); - client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); - }).doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java deleted file mode 100644 index a101f0177..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; - -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. - * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - - @Override - protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(10); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java deleted file mode 100644 index dd9f65895..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransportProvider} implementations. - * - * @author Christian Tzolov - */ -public abstract class AbstractMcpAsyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected McpServerTransportProvider createMcpTransportProvider(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport provider must not be null"); - - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); - } - - @Test - void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier - .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier - .create(mcpAsyncServer - .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( - resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( - resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - @Test - void testRemoveResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); - }); - } - - @Test - void testAddPromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( - prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePrompt() { - String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( - prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specification) - .build(); - - StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); - - assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeHandlers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - }))) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java deleted file mode 100644 index 6cbb8632c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ /dev/null @@ -1,401 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransportProvider} implementations. - * - * @author Christian Tzolov - */ -public abstract class AbstractMcpSyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected McpServerTransportProvider createMcpTransportProvider(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - // onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport provider must not be null"); - - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); - } - - @Test - void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer - .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( - resource, (exchange, req) -> new ReadResourceResult(List.of())); - - assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) - .hasMessage("Resource must not be null"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( - resource, (exchange, req) -> new ReadResourceResult(List.of())); - - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - @Test - void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt specification must not be null"); - } - - @Test - void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, - (exchange, req) -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); - - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, - (exchange, req) -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specification) - .build(); - - assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeHandlers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - })) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }, (exchange, roots) -> consumer2Called[0] = true)) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java deleted file mode 100644 index 208bcb71b..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java +++ /dev/null @@ -1,5 +0,0 @@ -package io.modelcontextprotocol.server; - -public abstract class BaseMcpAsyncServerTests { - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java deleted file mode 100644 index dc9d1cfab..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ /dev/null @@ -1,961 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server.transport; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.web.client.RestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - -class HttpServletSseServerTransportProviderIntegrationTests { - - private static final int PORT = TomcatTestUtil.findAvailablePort(); - - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private HttpServletSseServerTransportProvider mcpServerTransportProvider; - - McpClient.SyncSpec clientBuilder; - - private Tomcat tomcat; - - @BeforeEach - public void before() { - // Create and configure the transport provider - mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); - - tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); - try { - tomcat.start(); - assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()); - } - - @AfterEach - public void after() { - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - @Disabled - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - @Disabled - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpServer.close(); - } - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), - (exchange, request) -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @Test - void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - @Test - void testLoggingNotification() { - // Create a list to store received logging notifications - List receivedNotifications = new ArrayList<>(); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), - (exchange, request) -> { - - // Create and send notifications with different levels - - // This should be filtered out (DEBUG < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .block(); - - // This should be sent (NOTICE >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build()) - .block(); - - // This should be sent (ERROR > NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build()) - .block(); - - // This should be filtered out (INFO < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build()) - .block(); - - // This should be sent (ERROR >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build()) - .block(); - - return Mono.just(new CallToolResult("Logging test completed", false)); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - - System.out.println("Received notifications: " + receivedNotifications); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); - } - mcpServer.close(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java deleted file mode 100644 index f72be43e0..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.time.Duration; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.MockMcpClientTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, - * request-response correlation, and notification processing. - * - * @author Christian Tzolov - */ -class McpClientSessionTests { - - private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); - - private static final Duration TIMEOUT = Duration.ofSeconds(5); - - private static final String TEST_METHOD = "test.method"; - - private static final String TEST_NOTIFICATION = "test.notification"; - - private static final String ECHO_METHOD = "echo"; - - private McpClientSession session; - - private MockMcpClientTransport transport; - - @BeforeEach - void setUp() { - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); - } - - @AfterEach - void tearDown() { - if (session != null) { - session.close(); - } - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The requestTimeout can not be null"); - - assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("transport can not be null"); - } - - TypeReference responseType = new TypeReference<>() { - }; - - @Test - void testSendRequest() { - String testParam = "test parameter"; - String responseData = "test response"; - - // Create a Mono that will emit the response after the request is sent - Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); - // Verify response handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); - }).consumeNextWith(response -> { - // Verify the request was sent - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); - McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; - assertThat(request.method()).isEqualTo(TEST_METHOD); - assertThat(request.params()).isEqualTo(testParam); - assertThat(response).isEqualTo(responseData); - }).verifyComplete(); - } - - @Test - void testSendRequestWithError() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify error handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - // Simulate error response - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); - }).expectError(McpError.class).verify(); - } - - @Test - void testRequestTimeout() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify timeout - StepVerifier.create(responseMono) - .expectError(java.util.concurrent.TimeoutException.class) - .verify(TIMEOUT.plusSeconds(1)); - } - - @Test - void testSendNotification() { - Map params = Map.of("key", "value"); - Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); - - // Verify notification was sent - StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); - McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; - assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); - assertThat(notification.params()).isEqualTo(params); - }).verifyComplete(); - } - - @Test - void testRequestHandling() { - String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, - params -> Mono.just(params)); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); - - // Simulate incoming request - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); - transport.simulateIncomingMessage(request); - - // Verify response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.result()).isEqualTo(echoMessage); - assertThat(response.error()).isNull(); - } - - @Test - void testNotificationHandling() { - Sinks.One receivedParams = Sinks.one(); - - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - - // Simulate incoming notification from the server - Map notificationParams = Map.of("status", "ready"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - TEST_NOTIFICATION, notificationParams); - - transport.simulateIncomingMessage(notification); - - // Verify handler was called - assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); - } - - @Test - void testUnknownMethodHandling() { - // Simulate incoming request for unknown method - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); - transport.simulateIncomingMessage(request); - - // Verify error response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.error()).isNotNull(); - assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); - } - - @Test - void testGracefulShutdown() { - StepVerifier.create(session.closeGracefully()).verifyComplete(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java deleted file mode 100644 index df8176a4b..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ /dev/null @@ -1,1000 +0,0 @@ -/* -* Copyright 2025 - 2025 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; - -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import net.javacrumbs.jsonunit.core.Option; - -/** - * @author Christian Tzolov - */ -public class McpSchemaTests { - - ObjectMapper mapper = new ObjectMapper(); - - // Content Types Tests - - @Test - void testTextContent() throws Exception { - McpSchema.TextContent test = new McpSchema.TextContent("XXX"); - String value = mapper.writeValueAsString(test); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"type":"text","text":"XXX"}""")); - } - - @Test - void testTextContentDeserialization() throws Exception { - McpSchema.TextContent textContent = mapper.readValue(""" - {"type":"text","text":"XXX"}""", McpSchema.TextContent.class); - - assertThat(textContent).isNotNull(); - assertThat(textContent.type()).isEqualTo("text"); - assertThat(textContent.text()).isEqualTo("XXX"); - } - - @Test - void testContentDeserializationWrongType() throws Exception { - - assertThatThrownBy(() -> mapper.readValue(""" - {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) - .isInstanceOf(InvalidTypeIdException.class) - .hasMessageContaining( - "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [audio, image, resource, text]"); - } - - @Test - void testImageContent() throws Exception { - McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); - String value = mapper.writeValueAsString(test); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""")); - } - - @Test - void testImageContentDeserialization() throws Exception { - McpSchema.ImageContent imageContent = mapper.readValue(""" - {"type":"image","data":"base64encodeddata","mimeType":"image/png"}""", McpSchema.ImageContent.class); - assertThat(imageContent).isNotNull(); - assertThat(imageContent.type()).isEqualTo("image"); - assertThat(imageContent.data()).isEqualTo("base64encodeddata"); - assertThat(imageContent.mimeType()).isEqualTo("image/png"); - } - - @Test - void testAudioContent() throws Exception { - McpSchema.AudioContent audioContent = new McpSchema.AudioContent(null, "base64encodeddata", "audio/wav"); - String value = mapper.writeValueAsString(audioContent); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav"}""")); - } - - @Test - void testAudioContentDeserialization() throws Exception { - McpSchema.AudioContent audioContent = mapper.readValue(""" - {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav"}""", McpSchema.AudioContent.class); - assertThat(audioContent).isNotNull(); - assertThat(audioContent.type()).isEqualTo("audio"); - assertThat(audioContent.data()).isEqualTo("base64encodeddata"); - assertThat(audioContent.mimeType()).isEqualTo("audio/wav"); - } - - @Test - void testEmbeddedResource() throws Exception { - McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", - "text/plain", "Sample resource content"); - - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - - String value = mapper.writeValueAsString(test); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""")); - } - - @Test - void testEmbeddedResourceDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( - """ - {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"}}""", - McpSchema.EmbeddedResource.class); - assertThat(embeddedResource).isNotNull(); - assertThat(embeddedResource.type()).isEqualTo("resource"); - assertThat(embeddedResource.resource()).isNotNull(); - assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); - assertThat(embeddedResource.resource().mimeType()).isEqualTo("text/plain"); - assertThat(((TextResourceContents) embeddedResource.resource()).text()).isEqualTo("Sample resource content"); - } - - @Test - void testEmbeddedResourceWithBlobContents() throws Exception { - McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", - "application/octet-stream", "base64encodedblob"); - - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); - - String value = mapper.writeValueAsString(test); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""")); - } - - @Test - void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( - """ - {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob"}}""", - McpSchema.EmbeddedResource.class); - assertThat(embeddedResource).isNotNull(); - assertThat(embeddedResource.type()).isEqualTo("resource"); - assertThat(embeddedResource.resource()).isNotNull(); - assertThat(embeddedResource.resource().uri()).isEqualTo("resource://test"); - assertThat(embeddedResource.resource().mimeType()).isEqualTo("application/octet-stream"); - assertThat(((McpSchema.BlobResourceContents) embeddedResource.resource()).blob()) - .isEqualTo("base64encodedblob"); - } - - // JSON-RPC Message Types Tests - - @Test - void testJSONRPCRequest() throws Exception { - Map params = new HashMap<>(); - params.put("key", "value"); - - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, - params); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); - } - - @Test - void testJSONRPCNotification() throws Exception { - Map params = new HashMap<>(); - params.put("key", "value"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - "notification_method", params); - - String value = mapper.writeValueAsString(notification); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","method":"notification_method","params":{"key":"value"}}""")); - } - - @Test - void testJSONRPCResponse() throws Exception { - Map result = new HashMap<>(); - result.put("result_key", "result_value"); - - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); - - String value = mapper.writeValueAsString(response); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","id":1,"result":{"result_key":"result_value"}}""")); - } - - @Test - void testJSONRPCResponseWithError() throws Exception { - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); - - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); - - String value = mapper.writeValueAsString(response); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); - } - - // Initialization Tests - - @Test - void testInitializeRequest() throws Exception { - McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() - .roots(true) - .sampling() - .build(); - - McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - - McpSchema.InitializeRequest request = new McpSchema.InitializeRequest("2024-11-05", capabilities, clientInfo); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"protocolVersion":"2024-11-05","capabilities":{"roots":{"listChanged":true},"sampling":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}}""")); - } - - @Test - void testInitializeResult() throws Exception { - McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder() - .logging() - .prompts(true) - .resources(true, true) - .tools(true) - .build(); - - McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); - - McpSchema.InitializeResult result = new McpSchema.InitializeResult("2024-11-05", capabilities, serverInfo, - "Server initialized successfully"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{"listChanged":true},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"test-server","version":"1.0.0"},"instructions":"Server initialized successfully"}""")); - } - - // Resource Tests - - @Test - void testResource() throws Exception { - McpSchema.Annotations annotations = new McpSchema.Annotations( - Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - - McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", - "text/plain", annotations); - - String value = mapper.writeValueAsString(resource); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","annotations":{"audience":["user","assistant"],"priority":0.8}}""")); - } - - @Test - void testResourceBuilder() throws Exception { - McpSchema.Annotations annotations = new McpSchema.Annotations( - Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - - McpSchema.Resource resource = McpSchema.Resource.builder() - .uri("resource://test") - .name("Test Resource") - .description("A test resource") - .mimeType("text/plain") - .size(256L) - .annotations(annotations) - .build(); - - String value = mapper.writeValueAsString(resource); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"uri":"resource://test","name":"Test Resource","description":"A test resource","mimeType":"text/plain","size":256,"annotations":{"audience":["user","assistant"],"priority":0.8}}""")); - } - - @Test - void testResourceBuilderUriRequired() { - McpSchema.Annotations annotations = new McpSchema.Annotations( - Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - - McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() - .name("Test Resource") - .description("A test resource") - .mimeType("text/plain") - .size(256L) - .annotations(annotations); - - assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); - } - - @Test - void testResourceBuilderNameRequired() { - McpSchema.Annotations annotations = new McpSchema.Annotations( - Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - - McpSchema.Resource.Builder resourceBuilder = McpSchema.Resource.builder() - .uri("resource://test") - .description("A test resource") - .mimeType("text/plain") - .size(256L) - .annotations(annotations); - - assertThatThrownBy(resourceBuilder::build).isInstanceOf(java.lang.IllegalArgumentException.class); - } - - @Test - void testResourceTemplate() throws Exception { - McpSchema.Annotations annotations = new McpSchema.Annotations(Arrays.asList(McpSchema.Role.USER), 0.5); - - McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", - "A test resource template", "text/plain", annotations); - - String value = mapper.writeValueAsString(template); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"uriTemplate":"resource://{param}/test","name":"Test Template","description":"A test resource template","mimeType":"text/plain","annotations":{"audience":["user"],"priority":0.5}}""")); - } - - @Test - void testListResourcesResult() throws Exception { - McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", - "First test resource", "text/plain", null); - - McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", - "Second test resource", "application/json", null); - - McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), - "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"resources":[{"uri":"resource://test1","name":"Test Resource 1","description":"First test resource","mimeType":"text/plain"},{"uri":"resource://test2","name":"Test Resource 2","description":"Second test resource","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testListResourceTemplatesResult() throws Exception { - McpSchema.ResourceTemplate template1 = new McpSchema.ResourceTemplate("resource://{param}/test1", - "Test Template 1", "First test template", "text/plain", null); - - McpSchema.ResourceTemplate template2 = new McpSchema.ResourceTemplate("resource://{param}/test2", - "Test Template 2", "Second test template", "application/json", null); - - McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( - Arrays.asList(template1, template2), "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"resourceTemplates":[{"uriTemplate":"resource://{param}/test1","name":"Test Template 1","description":"First test template","mimeType":"text/plain"},{"uriTemplate":"resource://{param}/test2","name":"Test Template 2","description":"Second test template","mimeType":"application/json"}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testReadResourceRequest() throws Exception { - McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test"); - - String value = mapper.writeValueAsString(request); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"uri":"resource://test"}""")); - } - - @Test - void testReadResourceResult() throws Exception { - McpSchema.TextResourceContents contents1 = new McpSchema.TextResourceContents("resource://test1", "text/plain", - "Sample text content"); - - McpSchema.BlobResourceContents contents2 = new McpSchema.BlobResourceContents("resource://test2", - "application/octet-stream", "base64encodedblob"); - - McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2)); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"contents":[{"uri":"resource://test1","mimeType":"text/plain","text":"Sample text content"},{"uri":"resource://test2","mimeType":"application/octet-stream","blob":"base64encodedblob"}]}""")); - } - - // Prompt Tests - - @Test - void testPrompt() throws Exception { - McpSchema.PromptArgument arg1 = new McpSchema.PromptArgument("arg1", "First argument", true); - - McpSchema.PromptArgument arg2 = new McpSchema.PromptArgument("arg2", "Second argument", false); - - McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "A test prompt", Arrays.asList(arg1, arg2)); - - String value = mapper.writeValueAsString(prompt); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"name":"test-prompt","description":"A test prompt","arguments":[{"name":"arg1","description":"First argument","required":true},{"name":"arg2","description":"Second argument","required":false}]}""")); - } - - @Test - void testPromptMessage() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Hello, world!"); - - McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); - - String value = mapper.writeValueAsString(message); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"role":"user","content":{"type":"text","text":"Hello, world!"}}""")); - } - - @Test - void testListPromptsResult() throws Exception { - McpSchema.PromptArgument arg = new McpSchema.PromptArgument("arg", "An argument", true); - - McpSchema.Prompt prompt1 = new McpSchema.Prompt("prompt1", "First prompt", Collections.singletonList(arg)); - - McpSchema.Prompt prompt2 = new McpSchema.Prompt("prompt2", "Second prompt", Collections.emptyList()); - - McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), - "next-cursor"); - - String value = mapper.writeValueAsString(result); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"prompts":[{"name":"prompt1","description":"First prompt","arguments":[{"name":"arg","description":"An argument","required":true}]},{"name":"prompt2","description":"Second prompt","arguments":[]}],"nextCursor":"next-cursor"}""")); - } - - @Test - void testGetPromptRequest() throws Exception { - Map arguments = new HashMap<>(); - arguments.put("arg1", "value1"); - arguments.put("arg2", 42); - - McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); - - assertThat(mapper.readValue(""" - {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) - .isEqualTo(request); - } - - @Test - void testGetPromptResult() throws Exception { - McpSchema.TextContent content1 = new McpSchema.TextContent("System message"); - McpSchema.TextContent content2 = new McpSchema.TextContent("User message"); - - McpSchema.PromptMessage message1 = new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, content1); - - McpSchema.PromptMessage message2 = new McpSchema.PromptMessage(McpSchema.Role.USER, content2); - - McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", - Arrays.asList(message1, message2)); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"description":"A test prompt result","messages":[{"role":"assistant","content":{"type":"text","text":"System message"}},{"role":"user","content":{"type":"text","text":"User message"}}]}""")); - } - - // Tool Tests - - @Test - void testJsonSchema() throws Exception { - String schemaJson = """ - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "address": { - "$ref": "#/$defs/Address" - } - }, - "required": ["name"], - "$defs": { - "Address": { - "type": "object", - "properties": { - "street": {"type": "string"}, - "city": {"type": "string"} - }, - "required": ["street", "city"] - } - } - } - """; - - // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); - - // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); - - // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); - - // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); - - // The two serialized strings should be the same - assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); - } - - @Test - void testJsonSchemaWithDefinitions() throws Exception { - String schemaJson = """ - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "address": { - "$ref": "#/definitions/Address" - } - }, - "required": ["name"], - "definitions": { - "Address": { - "type": "object", - "properties": { - "street": {"type": "string"}, - "city": {"type": "string"} - }, - "required": ["street", "city"] - } - } - } - """; - - // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); - - // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); - - // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); - - // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); - - // The two serialized strings should be the same - assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); - } - - @Test - void testTool() throws Exception { - String schemaJson = """ - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "value": { - "type": "number" - } - }, - "required": ["name"] - } - """; - - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson); - - String value = mapper.writeValueAsString(tool); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); - } - - @Test - void testToolWithComplexSchema() throws Exception { - String complexSchemaJson = """ - { - "type": "object", - "$defs": { - "Address": { - "type": "object", - "properties": { - "street": {"type": "string"}, - "city": {"type": "string"} - }, - "required": ["street", "city"] - } - }, - "properties": { - "name": {"type": "string"}, - "shippingAddress": {"$ref": "#/$defs/Address"} - }, - "required": ["name", "shippingAddress"] - } - """; - - McpSchema.Tool tool = new McpSchema.Tool("addressTool", "Handles addresses", complexSchemaJson); - - // Serialize the tool to a string - String serialized = mapper.writeValueAsString(tool); - - // Deserialize back to a Tool object - McpSchema.Tool deserializedTool = mapper.readValue(serialized, McpSchema.Tool.class); - - // Serialize again and compare with first serialization - String serializedAgain = mapper.writeValueAsString(deserializedTool); - - // The two serialized strings should be the same - assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); - - // Just verify the basic structure was preserved - assertThat(deserializedTool.inputSchema().defs()).isNotNull(); - assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); - } - - @Test - void testToolWithAnnotations() throws Exception { - String schemaJson = """ - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "value": { - "type": "number" - } - }, - "required": ["name"] - } - """; - McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool", false, false, false, false, - false); - - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson, annotations); - - String value = mapper.writeValueAsString(tool); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]},"annotations":{"title":"A test tool","readOnlyHint":false,"destructiveHint":false,"idempotentHint":false,"openWorldHint":false,"returnDirect":false}}""")); - } - - @Test - void testCallToolRequest() throws Exception { - Map arguments = new HashMap<>(); - arguments.put("name", "test"); - arguments.put("value", 42); - - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"name":"test-tool","arguments":{"name":"test","value":42}}""")); - } - - @Test - void testCallToolRequestJsonArguments() throws Exception { - - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", """ - { - "name": "test", - "value": 42 - } - """); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"name":"test-tool","arguments":{"name":"test","value":42}}""")); - } - - @Test - void testCallToolResult() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); - - McpSchema.CallToolResult result = new McpSchema.CallToolResult(Collections.singletonList(content), false); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); - } - - @Test - void testCallToolResultBuilder() throws Exception { - McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() - .addTextContent("Tool execution result") - .isError(false) - .build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); - } - - @Test - void testCallToolResultBuilderWithMultipleContents() throws Exception { - McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); - McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); - - McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() - .addContent(textContent) - .addContent(imageContent) - .isError(false) - .build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":false}""")); - } - - @Test - void testCallToolResultBuilderWithContentList() throws Exception { - McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); - McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); - List contents = Arrays.asList(textContent, imageContent); - - McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":true}""")); - } - - @Test - void testCallToolResultBuilderWithErrorResult() throws Exception { - McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() - .addTextContent("Error: Operation failed") - .isError(true) - .build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Error: Operation failed"}],"isError":true}""")); - } - - @Test - void testCallToolResultStringConstructor() throws Exception { - // Test the existing string constructor alongside the builder - McpSchema.CallToolResult result1 = new McpSchema.CallToolResult("Simple result", false); - McpSchema.CallToolResult result2 = McpSchema.CallToolResult.builder() - .addTextContent("Simple result") - .isError(false) - .build(); - - String value1 = mapper.writeValueAsString(result1); - String value2 = mapper.writeValueAsString(result2); - - // Both should produce the same JSON - assertThat(value1).isEqualTo(value2); - assertThatJson(value1).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Simple result"}],"isError":false}""")); - } - - // Sampling Tests - - @Test - void testCreateMessageRequest() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("User message"); - - McpSchema.SamplingMessage message = new McpSchema.SamplingMessage(McpSchema.Role.USER, content); - - McpSchema.ModelHint hint = new McpSchema.ModelHint("gpt-4"); - - McpSchema.ModelPreferences preferences = new McpSchema.ModelPreferences(Collections.singletonList(hint), 0.3, - 0.7, 0.9); - - Map metadata = new HashMap<>(); - metadata.put("session", "test-session"); - - McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() - .messages(Collections.singletonList(message)) - .modelPreferences(preferences) - .systemPrompt("You are a helpful assistant") - .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) - .temperature(0.7) - .maxTokens(1000) - .stopSequences(Arrays.asList("STOP", "END")) - .metadata(metadata) - .build(); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"messages":[{"role":"user","content":{"type":"text","text":"User message"}}],"modelPreferences":{"hints":[{"name":"gpt-4"}],"costPriority":0.3,"speedPriority":0.7,"intelligencePriority":0.9},"systemPrompt":"You are a helpful assistant","includeContext":"thisServer","temperature":0.7,"maxTokens":1000,"stopSequences":["STOP","END"],"metadata":{"session":"test-session"}}""")); - } - - @Test - void testCreateMessageResult() throws Exception { - McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); - - McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder() - .role(McpSchema.Role.ASSISTANT) - .content(content) - .model("gpt-4") - .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) - .build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); - } - - @Test - void testCreateMessageResultUnknownStopReason() throws Exception { - String input = """ - {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"arbitrary value"}"""; - - McpSchema.CreateMessageResult value = mapper.readValue(input, McpSchema.CreateMessageResult.class); - - McpSchema.TextContent expectedContent = new McpSchema.TextContent("Assistant response"); - McpSchema.CreateMessageResult expected = McpSchema.CreateMessageResult.builder() - .role(McpSchema.Role.ASSISTANT) - .content(expectedContent) - .model("gpt-4") - .stopReason(McpSchema.CreateMessageResult.StopReason.UNKNOWN) - .build(); - assertThat(value).isEqualTo(expected); - } - - // Elicitation Tests - - @Test - void testCreateElicitationRequest() throws Exception { - McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() - .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", - Map.of("foo", Map.of("type", "string")))) - .build(); - - String value = mapper.writeValueAsString(request); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); - } - - @Test - void testCreateElicitationResult() throws Exception { - McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() - .content(Map.of("foo", "bar")) - .message(McpSchema.ElicitResult.Action.ACCEPT) - .build(); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"action":"accept","content":{"foo":"bar"}}""")); - } - - // Roots Tests - - @Test - void testRoot() throws Exception { - McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root"); - - String value = mapper.writeValueAsString(root); - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"uri":"file:///path/to/root","name":"Test Root"}""")); - } - - @Test - void testListRootsResult() throws Exception { - McpSchema.Root root1 = new McpSchema.Root("file:///path/to/root1", "First Root"); - - McpSchema.Root root2 = new McpSchema.Root("file:///path/to/root2", "Second Root"); - - McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2)); - - String value = mapper.writeValueAsString(result); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"roots":[{"uri":"file:///path/to/root1","name":"First Root"},{"uri":"file:///path/to/root2","name":"Second Root"}]}""")); - - } - -} diff --git a/pom.xml b/pom.xml index 3fd0857e8..f8bc3a9c2 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.18.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk @@ -59,23 +59,23 @@ 17 - 3.26.3 + 3.27.6 5.10.2 - 5.17.0 - 1.20.4 - 1.17.5 + 5.20.0 + 1.21.4 + 1.17.8 1.21.0 2.0.16 1.5.15 - 2.17.0 + 2.19.2 6.2.1 3.11.0 3.1.2 3.5.2 - 3.5.0 + 3.11.2 3.3.0 0.8.10 1.5.0 @@ -96,12 +96,16 @@ 4.2.0 7.1.0 4.1.0 + 2.0.0 mcp-bom mcp + mcp-core + mcp-json-jackson2 + mcp-json mcp-spring/mcp-spring-webflux mcp-spring/mcp-spring-webmvc mcp-test @@ -275,6 +279,7 @@ ${maven-javadoc-plugin.version} false + true false none