From 7827cdc113daa6bda9ea310ed27e9e877b616eb1 Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Wed, 26 Mar 2025 12:09:46 +0100 Subject: [PATCH 001/205] refactor(server): Fi StdioServerTransportProvider initialization flow (#74) Extract message processing initialization from StdioMcpSessionTransport constructor into a separate initProcessing() method. Signed-off-by: Christian Tzolov --- .../transport/StdioServerTransportProvider.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 6a7d29039..a8b980e90 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -93,7 +93,9 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection - this.session = sessionFactory.create(new StdioMcpSessionTransport()); + var transport = new StdioMcpSessionTransport(); + this.session = sessionFactory.create(transport); + transport.initProcessing(); } @Override @@ -142,10 +144,6 @@ public StdioMcpSessionTransport() { "stdio-inbound"); this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "stdio-outbound"); - - handleIncomingMessages(); - startInboundProcessing(); - startOutboundProcessing(); } @Override @@ -181,6 +179,12 @@ public void close() { logger.debug("Session transport closed"); } + private void initProcessing() { + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + } + private void handleIncomingMessages() { this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion From 25f3bad68d83367833f81da81714c3b0dcc7dcbd Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sat, 22 Mar 2025 16:50:33 +0100 Subject: [PATCH 002/205] refactor: remove deprecated 0.7.0 code These changes are part of the planned deprecation cycle announced in 0.8.0, with the deprecated classes scheduled for removal in 0.9.0 - Delete WebFluxSseServerTransport, WebMvcSseServerTransport, StdioServerTransport, and HttpServletSseServerTransport - Remove deprecated interfaces: ServerMcpTransport, ClientMcpTransport - Delete DefaultMcpSession implementation - Remove all deprecated test classes for the removed implementations - Update references to use McpServerTransport and McpClientTransport interfaces - Split MockMcpTransport into client and server implementations * Rename MockMcpTransport to MockMcpClientTransport in mcp/src/test * Create new MockMcpServerTransport implementation * Add MockMcpServerTransportProvider for server tests * Mark MockMcpTransport in mcp-test module as deprecated * Update all test classes to use the new implementations Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseServerTransport.java | 413 --------- .../WebFluxSseServerTransportProvider.java | 4 +- ...bFluxSseMcpAsyncServerDeprecatedTests.java | 55 -- ...ebFluxSseMcpSyncServerDeprecatecTests.java | 55 -- .../legacy/WebFluxSseIntegrationTests.java | 459 ---------- .../transport/WebMvcSseServerTransport.java | 385 -------- ...seAsyncServerTransportDeprecatedTests.java | 118 --- .../WebMvcSseIntegrationDeprecatedTests.java | 508 ----------- ...SseSyncServerTransportDeprecatedTests.java | 118 --- .../MockMcpTransport.java | 9 +- ...AbstractMcpAsyncServerDeprecatedTests.java | 465 ---------- .../AbstractMcpSyncServerDeprecatedTests.java | 431 --------- .../client/McpAsyncClient.java | 4 +- .../client/McpClient.java | 49 +- .../client/McpSyncClient.java | 6 +- .../server/McpAsyncServer.java | 814 +---------------- .../server/McpServer.java | 851 +----------------- .../server/McpServerFeatures.java | 269 ------ .../server/McpSyncServer.java | 110 --- .../HttpServletSseServerTransport.java | 419 --------- .../transport/StdioServerTransport.java | 259 ------ .../spec/ClientMcpTransport.java | 15 - .../spec/DefaultMcpSession.java | 291 ------ .../spec/McpClientSession.java | 4 +- .../spec/McpClientTransport.java | 3 +- .../spec/McpTransport.java | 17 - .../spec/ServerMcpTransport.java | 15 - ...sport.java => MockMcpClientTransport.java} | 12 +- .../MockMcpServerTransport.java | 66 ++ .../MockMcpServerTransportProvider.java | 63 ++ .../McpAsyncClientResponseHandlerTests.java | 26 +- .../client/McpClientProtocolVersionTests.java | 10 +- ...AbstractMcpAsyncServerDeprecatedTests.java | 466 ---------- .../AbstractMcpSyncServerDeprecatedTests.java | 433 --------- .../server/McpServerProtocolVersionTests.java | 43 +- ...rvletSseMcpAsyncServerDeprecatedTests.java | 26 - ...ervletSseMcpSyncServerDeprecatedTests.java | 26 - .../StdioMcpAsyncServerDeprecatedTests.java | 25 - .../server/StdioMcpAsyncServerTests.java | 1 - .../StdioMcpSyncServerDeprecatedTests.java | 25 - ...letSseServerTransportIntegrationTests.java | 328 ------- .../transport/StdioServerTransportTests.java | 157 ---- .../spec/McpClientSessionTests.java | 10 +- 43 files changed, 203 insertions(+), 7660 deletions(-) delete mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java delete mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java delete mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java delete mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java delete mode 100644 mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java delete mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java delete mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java delete mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java delete mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java delete mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java rename mcp/src/test/java/io/modelcontextprotocol/{MockMcpTransport.java => MockMcpClientTransport.java} (84%) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java deleted file mode 100644 index fb0b581e0..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ /dev/null @@ -1,413 +0,0 @@ -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; - -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; - -/** - * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using - * Server-Sent Events (SSE). This implementation provides a bidirectional communication - * channel between MCP clients and servers using HTTP POST for client-to-server messages - * and SSE for server-to-client messages. - * - *

- * Key features: - *

    - *
  • Implements the {@link ServerMcpTransport} interface for MCP server transport - * functionality
  • - *
  • Uses WebFlux for non-blocking request handling and SSE support
  • - *
  • Maintains client sessions for reliable message delivery
  • - *
  • Supports graceful shutdown with session cleanup
  • - *
  • Thread-safe message broadcasting to multiple clients
  • - *
- * - *

- * The transport sets up two main endpoints: - *

    - *
  • SSE endpoint (/sse) - For establishing SSE connections with clients
  • - *
  • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
  • - *
- * - *

- * This implementation is thread-safe and can handle multiple concurrent client - * connections. It uses {@link ConcurrentHashMap} for session management and Reactor's - * {@link Sinks} for thread-safe message broadcasting. - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see ServerSentEvent - * @deprecated This class will be removed in 0.9.0. Use - * {@link WebFluxSseServerTransportProvider}. - */ -@Deprecated -public class WebFluxSseServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - private final ObjectMapper objectMapper; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final RouterFunction routerFunction; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - private Function, Mono> connectHandler; - - /** - * Constructs a new WebFlux SSE server transport instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - } - - /** - * Constructs a new WebFlux SSE server transport instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Configures the message handler for this transport. In the WebFlux SSE - * implementation, this method stores the handler for processing incoming messages but - * doesn't establish any connections since the server accepts connections rather than - * initiating them. - * @param handler A function that processes incoming JSON-RPC messages and returns - * responses. This handler will be called for each message received through the - * message endpoint. - * @return An empty Mono since the server doesn't initiate connections - */ - @Override - public Mono connect(Function, Mono> handler) { - this.connectHandler = handler; - // Server-side transport doesn't initiate connections - return Mono.empty().then(); - } - - /** - * Broadcasts a JSON-RPC message to all connected clients through their SSE - * connections. The message is serialized to JSON and sent as a server-sent event to - * each active session. - * - *

- * The method: - *

    - *
  • Serializes the message to JSON
  • - *
  • Creates a server-sent event with the message data
  • - *
  • Attempts to send the event to all active sessions
  • - *
  • Tracks and reports any delivery failures
  • - *
- * @param message The JSON-RPC message to broadcast - * @return A Mono that completes when the message has been sent to all sessions, or - * errors if any session fails to receive the message - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - return Mono.create(sink -> { - try {// @formatter:off - String jsonText = objectMapper.writeValueAsString(message); - ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); - - if (failedSessions.isEmpty()) { - logger.debug("Successfully broadcast message to all sessions"); - sink.success(); - } - else { - String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessions); - logger.error(error); - sink.error(new RuntimeException(error)); - } // @formatter:on - } - catch (IOException e) { - logger.error("Failed to serialize message: {}", e.getMessage()); - sink.error(e); - } - }); - } - - /** - * Converts data from one type to another using the configured ObjectMapper. This - * method is primarily used for converting between different representations of - * JSON-RPC message data. - * @param The target type to convert to - * @param data The source data to convert - * @param typeRef Type reference describing the target type - * @return The converted data - * @throws IllegalArgumentException if the conversion fails - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. This method ensures all active - * sessions are properly closed and cleaned up. - * - *

- * The shutdown process: - *

    - *
  • Marks the transport as closing to prevent new connections
  • - *
  • Closes each active session
  • - *
  • Removes closed sessions from the sessions map
  • - *
  • Times out after 5 seconds if shutdown takes too long
  • - *
- * @return A Mono that completes when all sessions have been closed - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }).then(Mono.when(sessions.values().stream().map(session -> { - String sessionId = session.id; - return Mono.fromRunnable(() -> session.close()) - .then(Mono.delay(Duration.ofMillis(100))) - .then(Mono.fromRunnable(() -> sessions.remove(sessionId))); - }).toList())) - .timeout(Duration.ofSeconds(5)) - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")) - .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage())); - } - - /** - * Returns the WebFlux router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

- * The router function defines two endpoints: - *

    - *
  • GET {sseEndpoint} - For establishing SSE connections
  • - *
  • POST {messageEndpoint} - For receiving client messages
  • - *
- * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Handles new SSE connection requests from clients. Creates a new session for each - * connection and sets up the SSE event stream. - * - *

- * The handler performs the following steps: - *

    - *
  • Generates a unique session ID
  • - *
  • Creates a new ClientSession instance
  • - *
  • Sends the message endpoint URI as an initial event
  • - *
  • Sets up message forwarding for the session
  • - *
  • Handles connection cleanup on completion or errors
  • - *
- * @param request The incoming server request - * @return A response with the SSE event stream - */ - private Mono handleSseConnection(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - ClientSession session = new ClientSession(sessionId); - this.sessions.put(sessionId, session); - - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - // Send initial endpoint event - logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(messageEndpoint).build()); - - // Subscribe to session messages - session.messageSink.asFlux() - .doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", sessionId)) - .doOnComplete(() -> { - logger.debug("Session {} completed", sessionId); - sessions.remove(sessionId); - }) - .doOnError(error -> { - logger.error("Error in session {}: {}", sessionId, error.getMessage()); - sessions.remove(sessionId); - }) - .doOnCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }) - .subscribe(event -> { - logger.debug("Forwarding event to session {}: {}", sessionId, event); - sink.next(event); - }, sink::error, sink::complete); - - sink.onCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }); - }), ServerSentEvent.class); - } - - /** - * Handles incoming JSON-RPC messages from clients. Deserializes the message and - * processes it through the configured message handler. - * - *

- * The handler: - *

    - *
  • Deserializes the incoming JSON-RPC message
  • - *
  • Passes it through the message handler chain
  • - *
  • Returns appropriate HTTP responses based on processing results
  • - *
  • Handles various error conditions with appropriate error responses
  • - *
- * @param request The incoming server request containing the JSON-RPC message - * @return A response indicating the message processing result - */ - private Mono handleMessage(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - return request.bodyToMono(String.class).flatMap(body -> { - try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return Mono.just(message) - .transform(this.connectHandler) - .flatMap(response -> ServerResponse.ok().build()) - .onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); - } - }); - } - - /** - * Represents an active client SSE connection session. Manages the message sink for - * sending events to the client and handles session lifecycle. - * - *

- * Each session: - *

    - *
  • Has a unique identifier
  • - *
  • Maintains its own message sink for event broadcasting
  • - *
  • Supports clean shutdown through the close method
  • - *
- */ - private static class ClientSession { - - private final String id; - - private final Sinks.Many> messageSink; - - ClientSession(String id) { - this.id = id; - logger.debug("Creating new session: {}", id); - this.messageSink = Sinks.many().replay().latest(); - logger.debug("Session {} initialized with replay sink", id); - } - - void close() { - logger.debug("Closing session: {}", id); - Sinks.EmitResult result = messageSink.tryEmitComplete(); - if (result.isFailure()) { - logger.warn("Failed to complete message sink for session {}: {}", id, result); - } - else { - logger.debug("Successfully completed message sink for session {}", id); - } - } - - } - -} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index cf3eeae03..4e5d2fafb 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -8,10 +8,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,7 +18,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java deleted file mode 100644 index b460284ee..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; - -/** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { - - private static final int PORT = 8181; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - @Override - protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - return transport; - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java deleted file mode 100644 index be2bf6c7f..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; - -/** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpSyncServerDeprecatecTests extends AbstractMcpSyncServerDeprecatedTests { - - private static final int PORT = 8182; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxSseServerTransport transport; - - @Override - protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - return transport; - } - - @Override - protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java deleted file mode 100644 index 981e114c9..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java +++ /dev/null @@ -1,459 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server.legacy; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; - -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.awaitility.Awaitility.await; - -public class WebFluxSseIntegrationTests { - - private static final int PORT = 8182; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxSseServerTransport mcpServerTransport; - - ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); - - @BeforeEach - public void before() { - - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); - clientBulders.put("webflux", - McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new CreateMessageRequest(messages, modelPrefs, null, - CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var clientBuilder = clientBulders.get(clientType); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new CreateMessageRequest(messages, modelPrefs, null, - CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) throws InterruptedException { - - var clientBuilder = clientBulders.get(clientType); - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new CreateMessageRequest(messages, modelPrefs, null, - CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); - - // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No - // roots - // capability - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithEmptyRootsList(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleConsumers(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Close server while subscription is active - mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBulders.get(clientType); - - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java deleted file mode 100644 index 23193d106..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java +++ /dev/null @@ -1,385 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -import org.springframework.http.HttpStatus; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.RouterFunctions; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; -import org.springframework.web.servlet.function.ServerResponse.SseBuilder; - -/** - * Server-side implementation of the Model Context Protocol (MCP) transport layer using - * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides - * a bridge between synchronous WebMVC operations and reactive programming patterns to - * maintain compatibility with the reactive transport interface. - * - * @deprecated This class will be removed in 0.9.0. Use - * {@link WebMvcSseServerTransportProvider}. - * - *

- * Key features: - *

    - *
  • Implements bidirectional communication using HTTP POST for client-to-server - * messages and SSE for server-to-client messages
  • - *
  • Manages client sessions with unique IDs for reliable message delivery
  • - *
  • Supports graceful shutdown with proper session cleanup
  • - *
  • Provides JSON-RPC message handling through configured endpoints
  • - *
  • Includes built-in error handling and logging
  • - *
- * - *

- * The transport operates on two main endpoints: - *

    - *
  • {@code /sse} - The SSE endpoint where clients establish their event stream - * connection
  • - *
  • A configurable message endpoint where clients send their JSON-RPC messages via HTTP - * POST
  • - *
- * - *

- * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client - * sessions in a thread-safe manner. Each client session is assigned a unique ID and - * maintains its own SSE connection. - * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see RouterFunction - */ -@Deprecated -public class WebMvcSseServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - private final ObjectMapper objectMapper; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final RouterFunction routerFunction; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - /** - * The function to process incoming JSON-RPC messages and produce responses. - */ - private Function, Mono> connectHandler; - - /** - * Constructs a new WebMvcSseServerTransport instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - } - - /** - * Constructs a new WebMvcSseServerTransport instance with the default SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Sets up the message handler for this transport. In the WebMVC SSE implementation, - * this method only stores the handler for later use, as connections are initiated by - * clients rather than the server. - * @param connectionHandler The function to process incoming JSON-RPC messages and - * produce responses - * @return An empty Mono since the server doesn't initiate connections - */ - @Override - public Mono connect( - Function, Mono> connectionHandler) { - this.connectHandler = connectionHandler; - // Server-side transport doesn't initiate connections - return Mono.empty(); - } - - /** - * Broadcasts a message to all connected clients through their SSE connections. The - * message is serialized to JSON and sent as an SSE event with type "message". If any - * errors occur during sending to a particular client, they are logged but don't - * prevent sending to other clients. - * @param message The JSON-RPC message to broadcast to all connected clients - * @return A Mono that completes when the broadcast attempt is finished - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.fromRunnable(() -> { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return; - } - - try { - String jsonText = objectMapper.writeValueAsString(message); - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - sessions.values().forEach(session -> { - try { - session.sseBuilder.id(session.id).event(MESSAGE_EVENT_TYPE).data(jsonText); - } - catch (Exception e) { - logger.error("Failed to send message to session {}: {}", session.id, e.getMessage()); - session.sseBuilder.error(e); - } - }); - } - catch (IOException e) { - logger.error("Failed to serialize message: {}", e.getMessage()); - } - }); - } - - /** - * Handles new SSE connection requests from clients by creating a new session and - * establishing an SSE connection. This method: - *

    - *
  • Generates a unique session ID
  • - *
  • Creates a new ClientSession with an SSE builder
  • - *
  • Sends an initial endpoint event to inform the client where to send - * messages
  • - *
  • Maintains the session in the sessions map
  • - *
- * @param request The incoming server request - * @return A ServerResponse configured for SSE communication, or an error response if - * the server is shutting down or the connection fails - */ - private ServerResponse handleSseConnection(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - - // Send initial endpoint event - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - sessions.remove(sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - sessions.remove(sessionId); - }); - - ClientSession session = new ClientSession(sessionId, sseBuilder); - this.sessions.put(sessionId, session); - - try { - session.sseBuilder.id(session.id).event(ENDPOINT_EVENT_TYPE).data(messageEndpoint); - } - catch (Exception e) { - logger.error("Failed to poll event from session queue: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * Handles incoming JSON-RPC messages from clients. This method: - *
    - *
  • Deserializes the request body into a JSON-RPC message
  • - *
  • Processes the message through the configured connect handler
  • - *
  • Returns appropriate HTTP responses based on the processing result
  • - *
- * @param request The incoming server request containing the JSON-RPC message - * @return A ServerResponse indicating success (200 OK) or appropriate error status - * with error details in case of failures - */ - private ServerResponse handleMessage(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - try { - String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - // Convert the message to a Mono, apply the handler, and block for the - // response - @SuppressWarnings("unused") - McpSchema.JSONRPCMessage response = Mono.just(message).transform(connectHandler).block(); - - return ServerResponse.ok().build(); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().body(new McpError("Invalid message format")); - } - catch (Exception e) { - logger.error("Error handling message: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - /** - * Represents an active client session with its associated SSE connection. Each - * session maintains: - *
    - *
  • A unique session identifier
  • - *
  • An SSE builder for sending server events to the client
  • - *
  • Logging of session lifecycle events
  • - *
- */ - private static class ClientSession { - - private final String id; - - private final SseBuilder sseBuilder; - - /** - * Creates a new client session with the specified ID and SSE builder. - * @param id The unique identifier for this session - * @param sseBuilder The SSE builder for sending server events to the client - */ - ClientSession(String id, SseBuilder sseBuilder) { - this.id = id; - this.sseBuilder = sseBuilder; - logger.debug("Session {} initialized with SSE emitter", id); - } - - /** - * Closes this session by completing the SSE connection. Any errors during - * completion are logged but do not prevent the session from being marked as - * closed. - */ - void close() { - logger.debug("Closing session: {}", id); - try { - sseBuilder.complete(); - logger.debug("Successfully completed SSE emitter for session {}", id); - } - catch (Exception e) { - logger.warn("Failed to complete SSE emitter for session {}: {}", id, e.getMessage()); - // sseBuilder.error(e); - } - } - - } - - /** - * Converts data from one type to another using the configured ObjectMapper. This is - * particularly useful for handling complex JSON-RPC parameter types. - * @param data The source data object to convert - * @param typeRef The target type reference - * @return The converted object of type T - * @param The target type - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. This method: - *
    - *
  • Sets the closing flag to prevent new connections
  • - *
  • Closes all active SSE connections
  • - *
  • Removes all session records
  • - *
- * @return A Mono that completes when all cleanup operations are finished - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - this.isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - - sessions.values().forEach(session -> { - String sessionId = session.id; - session.close(); - sessions.remove(sessionId); - }); - - logger.debug("Graceful shutdown completed"); - }); - } - - /** - * Returns the RouterFunction that defines the HTTP endpoints for this transport. The - * router function handles two endpoints: - *
    - *
  • GET /sse - For establishing SSE connections
  • - *
  • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
  • - *
- * @return The configured RouterFunction for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java deleted file mode 100644 index c3f0e3220..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -@Deprecated -@Timeout(15) -class WebMvcSseAsyncServerTransportDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private static final int PORT = 8181; - - private Tomcat tomcat; - - private WebMvcSseServerTransport transport; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); - } - - } - - private AnnotationConfigWebApplicationContext appContext; - - @Override - protected ServerMcpTransport createMcpTransport() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transport; - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); - } - if (appContext != null) { - appContext.close(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java deleted file mode 100644 index f2b593d8d..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java +++ /dev/null @@ -1,508 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.test.StepVerifier; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.awaitility.Awaitility.await; - -@Deprecated -public class WebMvcSseIntegrationDeprecatedTests { - - private static final int PORT = 8183; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private WebMvcSseServerTransport mcpServerTransport; - - McpClient.SyncSpec clientBuilder; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); - } - - } - - private Tomcat tomcat; - - private AnnotationConfigWebApplicationContext appContext; - - @BeforeEach - public void before() { - - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - // Configure and start the connector with async support - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); - } - - @AfterEach - public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); - } - if (appContext != null) { - appContext.close(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @Test - void testCreateMessageWithoutSamplingCapabilities() { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @Test - void testCreateMessageSuccess() throws InterruptedException { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithoutCapability() { - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); - - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsWithMultipleConsumers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Close server while subscription is active - mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testInitialize() { - - var mcpServer = McpServer.sync(mcpServerTransport).build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.close(); - mcpServer.close(); - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java deleted file mode 100644 index 8656665ed..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -@Deprecated -@Timeout(15) -class WebMvcSseSyncServerTransportDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private static final int PORT = 8181; - - private Tomcat tomcat; - - private WebMvcSseServerTransport transport; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); - } - - } - - private AnnotationConfigWebApplicationContext appContext; - - @Override - protected ServerMcpTransport createMcpTransport() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transport; - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); - } - if (appContext != null) { - appContext.close(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - -} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index cef3fb9fa..5484a63c2 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -15,15 +15,18 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link McpServerTransport} * interfaces. + * + * @deprecated not used. to be removed in the future. */ -public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { +@Deprecated +public class MockMcpTransport implements McpClientTransport, McpServerTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java deleted file mode 100644 index 005d78f25..000000000 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java +++ /dev/null @@ -1,465 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -@Deprecated -public abstract class AbstractMcpAsyncServerDeprecatedTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); - } - - @Test - void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - @Test - void testRemoveResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); - }); - } - - @Test - void testAddPromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePrompt() { - String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); - - assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - }))) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - -} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java deleted file mode 100644 index c6625acaa..000000000 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -public abstract class AbstractMcpSyncServerDeprecatedTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - // onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); - } - - @Test - void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Resource must not be null"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - @Test - void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); - } - - @Test - void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - })) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 9cbef0500..379b47e23 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -14,10 +14,10 @@ import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -153,7 +153,7 @@ public class McpAsyncClient { * @param initializationTimeout the max timeout to await for the client-server * @param features the MCP Client supported features. */ - McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Duration initializationTimeout, + McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index 9c5f7b015..f7b179616 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -12,7 +12,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; @@ -102,26 +101,6 @@ */ public interface McpClient { - /** - * Start building a synchronous MCP client with the specified transport layer. The - * synchronous MCP client provides blocking operations. Synchronous clients wait for - * each operation to complete before returning, making them simpler to use but - * potentially less performant for concurrent operations. The transport layer handles - * the low-level communication between client and server using protocols like stdio or - * Server-Sent Events (SSE). - * @param transport The transport layer implementation for MCP communication. Common - * implementations include {@code StdioClientTransport} for stdio-based communication - * and {@code SseClientTransport} for SSE-based communication. - * @return A new builder instance for configuring the client - * @throws IllegalArgumentException if transport is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #sync(McpClientTransport)} - */ - @Deprecated - static SyncSpec sync(ClientMcpTransport transport) { - return new SyncSpec(transport); - } - /** * Start building a synchronous MCP client with the specified transport layer. The * synchronous MCP client provides blocking operations. Synchronous clients wait for @@ -139,26 +118,6 @@ static SyncSpec sync(McpClientTransport transport) { return new SyncSpec(transport); } - /** - * Start building an asynchronous MCP client with the specified transport layer. The - * asynchronous MCP client provides non-blocking operations. Asynchronous clients - * return reactive primitives (Mono/Flux) immediately, allowing for concurrent - * operations and reactive programming patterns. The transport layer handles the - * low-level communication between client and server using protocols like stdio or - * Server-Sent Events (SSE). - * @param transport The transport layer implementation for MCP communication. Common - * implementations include {@code StdioClientTransport} for stdio-based communication - * and {@code SseClientTransport} for SSE-based communication. - * @return A new builder instance for configuring the client - * @throws IllegalArgumentException if transport is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #async(McpClientTransport)} - */ - @Deprecated - static AsyncSpec async(ClientMcpTransport transport) { - return new AsyncSpec(transport); - } - /** * Start building an asynchronous MCP client with the specified transport layer. The * asynchronous MCP client provides non-blocking operations. Asynchronous clients @@ -194,7 +153,7 @@ static AsyncSpec async(McpClientTransport transport) { */ class SyncSpec { - private final ClientMcpTransport transport; + private final McpClientTransport transport; private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout @@ -216,7 +175,7 @@ class SyncSpec { private Function samplingHandler; - private SyncSpec(ClientMcpTransport transport) { + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; } @@ -433,7 +392,7 @@ public McpSyncClient build() { */ class AsyncSpec { - private final ClientMcpTransport transport; + private final McpClientTransport transport; private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout @@ -455,7 +414,7 @@ class AsyncSpec { private Function> samplingHandler; - private AsyncSpec(ClientMcpTransport transport) { + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index ec0a0dfdb..071d76462 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -66,12 +66,8 @@ public class McpSyncClient implements AutoCloseable { * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpClient#sync(McpClientTransport)} to obtain an instance. */ - @Deprecated - // TODO make the constructor package private post-deprecation - public McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate) { Assert.notNull(delegate, "The delegate can not be null"); this.delegate = delegate; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index ef69539ad..188b0f48e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,9 +4,7 @@ package io.modelcontextprotocol.server; -import java.time.Duration; import java.util.HashMap; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -14,21 +12,18 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; -import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,19 +81,6 @@ public class McpAsyncServer { this.delegate = null; } - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransport The transport layer implementation for MCP communication. - * @param features The MCP server supported features. - * @deprecated This constructor will beremoved in 0.9.0. Use - * {@link #McpAsyncServer(McpServerTransportProvider, ObjectMapper, McpServerFeatures.Async)} - * instead. - */ - @Deprecated - McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - this.delegate = new LegacyAsyncServer(mcpTransport, features); - } - /** * Create a new McpAsyncServer with the given transport provider and capabilities. * @param mcpTransportProvider The transport layer implementation for MCP @@ -127,28 +109,6 @@ public McpSchema.Implementation getServerInfo() { return this.delegate.getServerInfo(); } - /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - * @deprecated This will be removed in 0.9.0. Use - * {@link McpAsyncServerExchange#getClientCapabilities()}. - */ - @Deprecated - public ClientCapabilities getClientCapabilities() { - return this.delegate.getClientCapabilities(); - } - - /** - * Get the client implementation information. - * @return The client implementation details - * @deprecated This will be removed in 0.9.0. Use - * {@link McpAsyncServerExchange#getClientInfo()}. - */ - @Deprecated - public McpSchema.Implementation getClientInfo() { - return this.delegate.getClientInfo(); - } - /** * Gracefully closes the server, allowing any in-progress operations to complete. * @return A Mono that completes when the server has been closed @@ -164,45 +124,9 @@ public void close() { this.delegate.close(); } - /** - * Retrieves the list of all roots provided by the client. - * @return A Mono that emits the list of roots result. - * @deprecated This will be removed in 0.9.0. Use - * {@link McpAsyncServerExchange#listRoots()}. - */ - @Deprecated - public Mono listRoots() { - return this.delegate.listRoots(null); - } - - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of roots result containing - * @deprecated This will be removed in 0.9.0. Use - * {@link McpAsyncServerExchange#listRoots(String)}. - */ - @Deprecated - public Mono listRoots(String cursor) { - return this.delegate.listRoots(cursor); - } - // --------------------------------------- // Tool Management // --------------------------------------- - - /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. - */ - @Deprecated - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - return this.delegate.addTool(toolRegistration); - } - /** * Add a new tool specification at runtime. * @param toolSpecification The tool specification to add @@ -232,19 +156,6 @@ public Mono notifyToolsListChanged() { // --------------------------------------- // Resource Management // --------------------------------------- - - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. - */ - @Deprecated - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - return this.delegate.addResource(resourceHandler); - } - /** * Add a new resource handler at runtime. * @param resourceHandler The resource handler to add @@ -274,19 +185,6 @@ public Mono notifyResourcesListChanged() { // --------------------------------------- // Prompt Management // --------------------------------------- - - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. - */ - @Deprecated - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - return this.delegate.addPrompt(promptRegistration); - } - /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add @@ -330,33 +228,6 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN // --------------------------------------- // Sampling // --------------------------------------- - - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This - * flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. - * @param createMessageRequest The request to create a new message - * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - * @deprecated This will be removed in 0.9.0. Use - * {@link McpAsyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. - */ - @Deprecated - public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - return this.delegate.createMessage(createMessageRequest); - } - /** * This method is package-private and used for test only. Should not be called by user * code. @@ -492,18 +363,6 @@ public McpSchema.Implementation getServerInfo() { return this.serverInfo; } - @Override - @Deprecated - public ClientCapabilities getClientCapabilities() { - throw new IllegalStateException("This method is deprecated and should not be called"); - } - - @Override - @Deprecated - public McpSchema.Implementation getClientInfo() { - throw new IllegalStateException("This method is deprecated and should not be called"); - } - @Override public Mono closeGracefully() { return this.mcpTransportProvider.closeGracefully(); @@ -514,18 +373,6 @@ public void close() { this.mcpTransportProvider.close(); } - @Override - @Deprecated - public Mono listRoots() { - return this.listRoots(null); - } - - @Override - @Deprecated - public Mono listRoots(String cursor) { - return Mono.error(new RuntimeException("Not implemented")); - } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> exchange.listRoots() @@ -574,11 +421,6 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica }); } - @Override - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - return this.addTool(toolRegistration.toSpecification()); - } - @Override public Mono removeTool(String toolName) { if (toolName == null) { @@ -661,11 +503,6 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou }); } - @Override - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - return this.addResource(resourceHandler.toSpecification()); - } - @Override public Mono removeResource(String resourceUri) { if (resourceUri == null) { @@ -756,11 +593,6 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe }); } - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - return this.addPrompt(promptRegistration.toSpecification()); - } - @Override public Mono removePrompt(String promptName) { if (promptName == null) { @@ -859,648 +691,6 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { // --------------------------------------- @Override - @Deprecated - public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - return Mono.error(new RuntimeException("Not implemented")); - } - - @Override - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } - - } - - private static final class LegacyAsyncServer extends McpAsyncServer { - - /** - * The MCP session implementation that manages bidirectional JSON-RPC - * communication between clients and servers. - */ - private final McpClientSession mcpSession; - - private final ServerMcpTransport transport; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private McpSchema.ClientCapabilities clientCapabilities; - - private McpSchema.Implementation clientInfo; - - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransport The transport layer implementation for MCP communication. - * @param features The MCP server supported features. - */ - LegacyAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just(Map.of())); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - - List, Mono>> rootsChangeHandlers = features - .rootsChangeConsumers(); - - List, Mono>> rootsChangeConsumers = rootsChangeHandlers.stream() - .map(handler -> (Function, Mono>) (roots) -> handler.apply(null, roots)) - .toList(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( - "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - this.mcpSession = new McpClientSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); - } - - @Override - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - throw new IllegalArgumentException( - "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); - } - - @Override - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { - throw new IllegalArgumentException( - "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); - } - - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - throw new IllegalArgumentException( - "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private McpClientSession.RequestHandler asyncInitializeRequestHandler() { - return params -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports - // if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST - // respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); - }; - } - - /** - * Get the server capabilities that define the supported features and - * functionality. - * @return The server capabilities - */ - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - /** - * Get the server implementation information. - * @return The server implementation details - */ - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - - /** - * Get the client capabilities that define the supported features and - * functionality. - * @return The client capabilities - */ - public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; - } - - /** - * Get the client implementation information. - * @return The client implementation details - */ - public McpSchema.Implementation getClientInfo() { - return this.clientInfo; - } - - /** - * Gracefully closes the server, allowing any in-progress operations to complete. - * @return A Mono that completes when the server has been closed - */ - public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); - } - - /** - * Close the server immediately. - */ - public void close() { - this.mcpSession.close(); - } - - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Retrieves the list of all roots provided by the client. - * @return A Mono that emits the list of roots result. - */ - public Mono listRoots() { - return this.listRoots(null); - } - - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of roots result containing - */ - public Mono listRoots(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); - } - - private McpClientSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - */ - @Override - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(toolRegistration.toSpecification()); - logger.debug("Added tool handler: {}", toolRegistration.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a tool handler at runtime. - * @param toolName The name of the tool handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyToolsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private McpClientSession.RequestHandler toolsListRequestHandler() { - return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private McpClientSession.RequestHandler toolsCallRequestHandler() { - return params -> { - McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolRegistration.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolRegistration.map(tool -> tool.call().apply(null, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } - - // --------------------------------------- - // Resource Management - // --------------------------------------- - - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - */ - @Override - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), - resourceHandler.toSpecification()) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a resource handler at runtime. - * @param resourceUri The URI of the resource handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private McpClientSession.RequestHandler resourcesListRequestHandler() { - return params -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private McpClientSession.RequestHandler resourceTemplateListRequestHandler() { - return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); - - } - - private McpClientSession.RequestHandler resourcesReadRequestHandler() { - return params -> { - McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(null, resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- - - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - */ - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration.toSpecification()); - if (registration != null) { - return Mono.error(new McpError( - "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); - } - - /** - * Remove a prompt handler at runtime. - * @param promptName The name of the prompt handler to remove - * @return Mono that completes when clients have been notified of the change - */ - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyPromptsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpClientSession.RequestHandler promptsListRequestHandler() { - return params -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private McpClientSession.RequestHandler promptsGetRequestHandler() { - return params -> { - McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification registration = this.prompts.get(promptRequest.name()); - if (registration == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return registration.promptHandler().apply(null, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - /** - * Send a logging message notification to all connected clients. Messages below - * the current minimum logging level will be filtered out. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - */ - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - Map params = this.transport.unmarshalFrom(loggingMessageNotification, - new TypeReference>() { - }); - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); - } - - /** - * Handles requests to set the minimum logging level. Messages below this level - * will not be sent. - * @return A handler that processes logging level change requests - */ - private McpClientSession.RequestHandler setLoggerRequestHandler() { - return params -> { - this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { - }); - - return Mono.empty(); - }; - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; - - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. - * This flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server - * API keys necessary. Servers can request text or image-based interactions and - * optionally include context from MCP servers in their prompts. - * @param createMessageRequest The request to create a new message - * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - */ - public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - - if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); - } - if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); - } - - /** - * This method is package-private and used for test only. Should not be called by - * user code. - * @param protocolVersions the Client supported protocol versions. - */ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d8dfcb018..091efac2f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -11,16 +11,12 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -136,21 +132,6 @@ static SyncSpecification sync(McpServerTransportProvider transportProvider) { return new SyncSpecification(transportProvider); } - /** - * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers block the current Thread's execution upon each request before - * giving the control back to the caller, making them simpler to implement but - * potentially less scalable for concurrent operations. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. - * @deprecated This method will be removed in 0.9.0. Use - * {@link #sync(McpServerTransportProvider)} instead. - */ - @Deprecated - static SyncSpec sync(ServerMcpTransport transport) { - return new SyncSpec(transport); - } - /** * Starts building an asynchronous MCP server that provides non-blocking operations. * Asynchronous servers can handle multiple requests concurrently on a single Thread @@ -163,21 +144,6 @@ static AsyncSpecification async(McpServerTransportProvider transportProvider) { return new AsyncSpecification(transportProvider); } - /** - * Starts building an asynchronous MCP server that provides non-blocking operations. - * Asynchronous servers can handle multiple requests concurrently on a single Thread - * using a functional paradigm with non-blocking server transports, making them more - * scalable for high-concurrency scenarios but more complex to implement. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link AsyncSpec} for configuring the server. - * @deprecated This method will be removed in 0.9.0. Use - * {@link #async(McpServerTransportProvider)} instead. - */ - @Deprecated - static AsyncSpec async(ServerMcpTransport transport) { - return new AsyncSpec(transport); - } - /** * Asynchronous server specification. */ @@ -1004,819 +970,4 @@ public McpSyncServer build() { } - /** - * Asynchronous server specification. - * - * @deprecated - */ - @Deprecated - class AsyncSpec { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private final ServerMcpTransport transport; - - private ObjectMapper objectMapper; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); - - private AsyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public AsyncSpec serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public AsyncSpec serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
    - *
  • Tool execution - *
  • Resource access - *
  • Prompt handling - *
  • Streaming responses - *
  • Batch operations - *
- * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public AsyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.AsyncToolRegistration} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public AsyncSpec tool(McpSchema.Tool tool, Function, Mono> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.AsyncToolRegistration(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolRegistrations The list of tool registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(McpServerFeatures.AsyncToolRegistration...) - */ - public AsyncSpec tools(List toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

- * Example usage:

{@code
-		 * .tools(
-		 *     new McpServerFeatures.AsyncToolRegistration(calculatorTool, calculatorHandler),
-		 *     new McpServerFeatures.AsyncToolRegistration(weatherTool, weatherHandler),
-		 *     new McpServerFeatures.AsyncToolRegistration(fileManagerTool, fileManagerHandler)
-		 * )
-		 * }
- * @param toolRegistrations The tool registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(List) - */ - public AsyncSpec tools(McpServerFeatures.AsyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.AsyncToolRegistration tool : toolRegistrations) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) - */ - public AsyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.AsyncResourceRegistration...) - */ - public AsyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegsitrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

- * Example usage:

{@code
-		 * .resources(
-		 *     new McpServerFeatures.AsyncResourceRegistration(fileResource, fileHandler),
-		 *     new McpServerFeatures.AsyncResourceRegistration(dbResource, dbHandler),
-		 *     new McpServerFeatures.AsyncResourceRegistration(apiResource, apiHandler)
-		 * )
-		 * }
- * @param resourceRegistrations The resource registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null - */ - public AsyncSpec resources(McpServerFeatures.AsyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.AsyncResourceRegistration resource : resourceRegistrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @see #resourceTemplates(ResourceTemplate...) - */ - public AsyncSpec resourceTemplates(List resourceTemplates) { - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @see #resourceTemplates(List) - */ - public AsyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

- * Example usage:

{@code
-		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptRegistration(
-		 *     new Prompt("analysis", "Code analysis template"),
-		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
-		 * )));
-		 * }
- * @param prompts Map of prompt name to registration. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpec prompts(Map prompts) { - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.AsyncPromptRegistration...) - */ - public AsyncSpec prompts(List prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

- * Example usage:

{@code
-		 * .prompts(
-		 *     new McpServerFeatures.AsyncPromptRegistration(analysisPrompt, analysisHandler),
-		 *     new McpServerFeatures.AsyncPromptRegistration(summaryPrompt, summaryHandler),
-		 *     new McpServerFeatures.AsyncPromptRegistration(reviewPrompt, reviewHandler)
-		 * )
-		 * }
- * @param prompts The prompt registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { - for (McpServerFeatures.AsyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param consumer The consumer to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param consumers The list of consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param consumers The consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public AsyncSpec rootsChangeConsumers( - @SuppressWarnings("unchecked") Function, Mono>... consumers) { - for (Function, Mono> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } - return this; - } - - /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's - * settings - */ - public McpAsyncServer build() { - var tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::toSpecification).toList(); - - var resources = this.resources.entrySet() - .stream() - .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - - var prompts = this.prompts.entrySet() - .stream() - .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - - var rootsChangeHandlers = this.rootsChangeConsumers.stream() - .map(consumer -> (BiFunction, Mono>) (exchange, - roots) -> consumer.apply(roots)) - .toList(); - - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, tools, resources, - this.resourceTemplates, prompts, rootsChangeHandlers); - - return new McpAsyncServer(this.transport, features); - } - - } - - /** - * Synchronous server specification. - * - * @deprecated - */ - @Deprecated - class SyncSpec { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private final ServerMcpTransport transport; - - private final McpServerTransportProvider transportProvider; - - private ObjectMapper objectMapper; - - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * The Model Context Protocol (MCP) allows servers to expose tools that can be - * invoked by language models. Tools enable models to interact with external - * systems, such as querying databases, calling APIs, or performing computations. - * Each tool is uniquely identified by a name and includes metadata describing its - * schema. - */ - private final List tools = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose resources to clients. Resources allow servers to share data that - * provides context to language models, such as files, database schemas, or - * application-specific information. Each resource is uniquely identified by a - * URI. - */ - private final Map resources = new HashMap<>(); - - private final List resourceTemplates = new ArrayList<>(); - - /** - * The Model Context Protocol (MCP) provides a standardized way for servers to - * expose prompt templates to clients. Prompts allow servers to provide structured - * messages and instructions for interacting with language models. Clients can - * discover available prompts, retrieve their contents, and provide arguments to - * customize them. - */ - private final Map prompts = new HashMap<>(); - - private final List>> rootsChangeConsumers = new ArrayList<>(); - - private SyncSpec(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - this.transport = null; - } - - private SyncSpec(ServerMcpTransport transport) { - Assert.notNull(transport, "Transport must not be null"); - this.transport = transport; - this.transportProvider = null; - } - - /** - * Sets the server implementation information that will be shared with clients - * during connection initialization. This helps with version compatibility, - * debugging, and server identification. - * @param serverInfo The server implementation details including name and version. - * Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverInfo is null - */ - public SyncSpec serverInfo(McpSchema.Implementation serverInfo) { - Assert.notNull(serverInfo, "Server info must not be null"); - this.serverInfo = serverInfo; - return this; - } - - /** - * Sets the server implementation information using name and version strings. This - * is a convenience method alternative to - * {@link #serverInfo(McpSchema.Implementation)}. - * @param name The server name. Must not be null or empty. - * @param version The server version. Must not be null or empty. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if name or version is null or empty - * @see #serverInfo(McpSchema.Implementation) - */ - public SyncSpec serverInfo(String name, String version) { - Assert.hasText(name, "Name must not be null or empty"); - Assert.hasText(version, "Version must not be null or empty"); - this.serverInfo = new McpSchema.Implementation(name, version); - return this; - } - - /** - * Sets the server capabilities that will be advertised to clients during - * connection initialization. Capabilities define what features the server - * supports, such as: - *
    - *
  • Tool execution - *
  • Resource access - *
  • Prompt handling - *
  • Streaming responses - *
  • Batch operations - *
- * @param serverCapabilities The server capabilities configuration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if serverCapabilities is null - */ - public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { - this.serverCapabilities = serverCapabilities; - return this; - } - - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.SyncToolRegistration} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     args -> new CallToolResult("Result: " + calculate(args))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - */ - public SyncSpec tool(McpSchema.Tool tool, Function, McpSchema.CallToolResult> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - - this.tools.add(new McpServerFeatures.SyncToolRegistration(tool, handler)); - - return this; - } - - /** - * Adds multiple tools with their handlers to the server using a List. This method - * is useful when tools are dynamically generated or loaded from a configuration - * source. - * @param toolRegistrations The list of tool registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(McpServerFeatures.SyncToolRegistration...) - */ - public SyncSpec tools(List toolRegistrations) { - Assert.notNull(toolRegistrations, "Tool handlers list must not be null"); - this.tools.addAll(toolRegistrations); - return this; - } - - /** - * Adds multiple tools with their handlers to the server using varargs. This - * method provides a convenient way to register multiple tools inline. - * - *

- * Example usage:

{@code
-		 * .tools(
-		 *     new ToolRegistration(calculatorTool, calculatorHandler),
-		 *     new ToolRegistration(weatherTool, weatherHandler),
-		 *     new ToolRegistration(fileManagerTool, fileManagerHandler)
-		 * )
-		 * }
- * @param toolRegistrations The tool registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if toolRegistrations is null - * @see #tools(List) - */ - public SyncSpec tools(McpServerFeatures.SyncToolRegistration... toolRegistrations) { - for (McpServerFeatures.SyncToolRegistration tool : toolRegistrations) { - this.tools.add(tool); - } - return this; - } - - /** - * Registers multiple resources with their handlers using a Map. This method is - * useful when resources are dynamically generated or loaded from a configuration - * source. - * @param resourceRegsitrations Map of resource name to registration. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) - */ - public SyncSpec resources(Map resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers map must not be null"); - this.resources.putAll(resourceRegsitrations); - return this; - } - - /** - * Registers multiple resources with their handlers using a List. This method is - * useful when resources need to be added in bulk from a collection. - * @param resourceRegsitrations List of resource registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegsitrations is null - * @see #resources(McpServerFeatures.SyncResourceRegistration...) - */ - public SyncSpec resources(List resourceRegsitrations) { - Assert.notNull(resourceRegsitrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegsitrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Registers multiple resources with their handlers using varargs. This method - * provides a convenient way to register multiple resources inline. - * - *

- * Example usage:

{@code
-		 * .resources(
-		 *     new ResourceRegistration(fileResource, fileHandler),
-		 *     new ResourceRegistration(dbResource, dbHandler),
-		 *     new ResourceRegistration(apiResource, apiHandler)
-		 * )
-		 * }
- * @param resourceRegistrations The resource registrations to add. Must not be - * null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if resourceRegistrations is null - */ - public SyncSpec resources(McpServerFeatures.SyncResourceRegistration... resourceRegistrations) { - Assert.notNull(resourceRegistrations, "Resource handlers list must not be null"); - for (McpServerFeatures.SyncResourceRegistration resource : resourceRegistrations) { - this.resources.put(resource.resource().uri(), resource); - } - return this; - } - - /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. - * @return This builder instance for method chaining - * @see #resourceTemplates(ResourceTemplate...) - */ - public SyncSpec resourceTemplates(List resourceTemplates) { - this.resourceTemplates.addAll(resourceTemplates); - return this; - } - - /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. - * @return This builder instance for method chaining - * @see #resourceTemplates(List) - */ - public SyncSpec resourceTemplates(ResourceTemplate... resourceTemplates) { - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using a Map. This method is - * useful when prompts are dynamically generated or loaded from a configuration - * source. - * - *

- * Example usage:

{@code
-		 * Map prompts = new HashMap<>();
-		 * prompts.put("analysis", new PromptRegistration(
-		 *     new Prompt("analysis", "Code analysis template"),
-		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
-		 * ));
-		 * .prompts(prompts)
-		 * }
- * @param prompts Map of prompt name to registration. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpec prompts(Map prompts) { - this.prompts.putAll(prompts); - return this; - } - - /** - * Registers multiple prompts with their handlers using a List. This method is - * useful when prompts need to be added in bulk from a collection. - * @param prompts List of prompt registrations. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - * @see #prompts(McpServerFeatures.SyncPromptRegistration...) - */ - public SyncSpec prompts(List prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers multiple prompts with their handlers using varargs. This method - * provides a convenient way to register multiple prompts inline. - * - *

- * Example usage:

{@code
-		 * .prompts(
-		 *     new PromptRegistration(analysisPrompt, analysisHandler),
-		 *     new PromptRegistration(summaryPrompt, summaryHandler),
-		 *     new PromptRegistration(reviewPrompt, reviewHandler)
-		 * )
-		 * }
- * @param prompts The prompt registrations to add. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if prompts is null - */ - public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { - for (McpServerFeatures.SyncPromptRegistration prompt : prompts) { - this.prompts.put(prompt.prompt().name(), prompt); - } - return this; - } - - /** - * Registers a consumer that will be notified when the list of roots changes. This - * is useful for updating resource availability dynamically, such as when new - * files are added or removed. - * @param consumer The consumer to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumer is null - */ - public SyncSpec rootsChangeConsumer(Consumer> consumer) { - Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes. This method is useful when multiple consumers need to be registered at - * once. - * @param consumers The list of consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public SyncSpec rootsChangeConsumers(List>> consumers) { - Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); - return this; - } - - /** - * Registers multiple consumers that will be notified when the list of roots - * changes using varargs. This method provides a convenient way to register - * multiple consumers inline. - * @param consumers The consumers to register. Must not be null. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if consumers is null - */ - public SyncSpec rootsChangeConsumers(Consumer>... consumers) { - for (Consumer> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } - return this; - } - - /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings - */ - public McpSyncServer build() { - var tools = this.tools.stream().map(McpServerFeatures.SyncToolRegistration::toSpecification).toList(); - - var resources = this.resources.entrySet() - .stream() - .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - - var prompts = this.prompts.entrySet() - .stream() - .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - - var rootsChangeHandlers = this.rootsChangeConsumers.stream() - .map(consumer -> (BiConsumer>) (exchange, roots) -> consumer - .accept(roots)) - .toList(); - - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - tools, resources, this.resourceTemplates, prompts, rootsChangeHandlers); - - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); - var asyncServer = new McpAsyncServer(this.transport, asyncFeatures); - - return new McpSyncServer(asyncServer); - } - - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 5aeeadd77..8c110027c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -10,7 +10,6 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; -import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; @@ -423,272 +422,4 @@ public record SyncPromptSpecification(McpSchema.Prompt prompt, BiFunction promptHandler) { } - // --------------------------------------- - // Deprecated registrations - // --------------------------------------- - - /** - * Registration of a tool with its asynchronous handler function. Tools are the - * primary way for MCP servers to expose functionality to AI models. Each tool - * represents a specific capability, such as: - *
    - *
  • Performing calculations - *
  • Accessing external APIs - *
  • Querying databases - *
  • Manipulating files - *
  • Executing system commands - *
- * - *

- * Example tool registration:

{@code
-	 * new McpServerFeatures.AsyncToolRegistration(
-	 *     new Tool(
-	 *         "calculator",
-	 *         "Performs mathematical calculations",
-	 *         new JsonSchemaObject()
-	 *             .required("expression")
-	 *             .property("expression", JsonSchemaType.STRING)
-	 *     ),
-	 *     args -> {
-	 *         String expr = (String) args.get("expression");
-	 *         return Mono.just(new CallToolResult("Result: " + evaluate(expr)));
-	 *     }
-	 * )
-	 * }
- * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link AsyncToolSpecification}. - */ - @Deprecated - public record AsyncToolRegistration(McpSchema.Tool tool, - Function, Mono> call) { - - static AsyncToolRegistration fromSync(SyncToolRegistration tool) { - // FIXME: This is temporary, proper validation should be implemented - if (tool == null) { - return null; - } - return new AsyncToolRegistration(tool.tool(), - map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); - } - - public AsyncToolSpecification toSpecification() { - return new AsyncToolSpecification(tool(), (exchange, map) -> call.apply(map)); - } - } - - /** - * Registration of a resource with its asynchronous handler function. Resources - * provide context to AI models by exposing data such as: - *
    - *
  • File contents - *
  • Database records - *
  • API responses - *
  • System information - *
  • Application state - *
- * - *

- * Example resource registration:

{@code
-	 * new McpServerFeatures.AsyncResourceRegistration(
-	 *     new Resource("docs", "Documentation files", "text/markdown"),
-	 *     request -> {
-	 *         String content = readFile(request.getPath());
-	 *         return Mono.just(new ReadResourceResult(content));
-	 *     }
-	 * )
-	 * }
- * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link AsyncResourceSpecification}. - */ - @Deprecated - public record AsyncResourceRegistration(McpSchema.Resource resource, - Function> readHandler) { - - static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { - // FIXME: This is temporary, proper validation should be implemented - if (resource == null) { - return null; - } - return new AsyncResourceRegistration(resource.resource(), - req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - - public AsyncResourceSpecification toSpecification() { - return new AsyncResourceSpecification(resource(), (exchange, request) -> readHandler.apply(request)); - } - } - - /** - * Registration of a prompt template with its asynchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
    - *
  • Consistent message formatting - *
  • Parameter substitution - *
  • Context injection - *
  • Response formatting - *
  • Instruction templating - *
- * - *

- * Example prompt registration:

{@code
-	 * new McpServerFeatures.AsyncPromptRegistration(
-	 *     new Prompt("analyze", "Code analysis template"),
-	 *     request -> {
-	 *         String code = request.getArguments().get("code");
-	 *         return Mono.just(new GetPromptResult(
-	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
-	 *         ));
-	 *     }
-	 * )
-	 * }
- * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link AsyncPromptSpecification}. - */ - @Deprecated - public record AsyncPromptRegistration(McpSchema.Prompt prompt, - Function> promptHandler) { - - static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { - // FIXME: This is temporary, proper validation should be implemented - if (prompt == null) { - return null; - } - return new AsyncPromptRegistration(prompt.prompt(), - req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) - .subscribeOn(Schedulers.boundedElastic())); - } - - public AsyncPromptSpecification toSpecification() { - return new AsyncPromptSpecification(prompt(), (exchange, request) -> promptHandler.apply(request)); - } - } - - /** - * Registration of a tool with its synchronous handler function. Tools are the primary - * way for MCP servers to expose functionality to AI models. Each tool represents a - * specific capability, such as: - *
    - *
  • Performing calculations - *
  • Accessing external APIs - *
  • Querying databases - *
  • Manipulating files - *
  • Executing system commands - *
- * - *

- * Example tool registration:

{@code
-	 * new McpServerFeatures.SyncToolRegistration(
-	 *     new Tool(
-	 *         "calculator",
-	 *         "Performs mathematical calculations",
-	 *         new JsonSchemaObject()
-	 *             .required("expression")
-	 *             .property("expression", JsonSchemaType.STRING)
-	 *     ),
-	 *     args -> {
-	 *         String expr = (String) args.get("expression");
-	 *         return new CallToolResult("Result: " + evaluate(expr));
-	 *     }
-	 * )
-	 * }
- * - * @param tool The tool definition including name, description, and parameter schema - * @param call The function that implements the tool's logic, receiving arguments and - * returning results - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link SyncToolSpecification}. - */ - @Deprecated - public record SyncToolRegistration(McpSchema.Tool tool, - Function, McpSchema.CallToolResult> call) { - public SyncToolSpecification toSpecification() { - return new SyncToolSpecification(tool, (exchange, map) -> call.apply(map)); - } - } - - /** - * Registration of a resource with its synchronous handler function. Resources provide - * context to AI models by exposing data such as: - *
    - *
  • File contents - *
  • Database records - *
  • API responses - *
  • System information - *
  • Application state - *
- * - *

- * Example resource registration:

{@code
-	 * new McpServerFeatures.SyncResourceRegistration(
-	 *     new Resource("docs", "Documentation files", "text/markdown"),
-	 *     request -> {
-	 *         String content = readFile(request.getPath());
-	 *         return new ReadResourceResult(content);
-	 *     }
-	 * )
-	 * }
- * - * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link SyncResourceSpecification}. - */ - @Deprecated - public record SyncResourceRegistration(McpSchema.Resource resource, - Function readHandler) { - public SyncResourceSpecification toSpecification() { - return new SyncResourceSpecification(resource, (exchange, request) -> readHandler.apply(request)); - } - } - - /** - * Registration of a prompt template with its synchronous handler function. Prompts - * provide structured templates for AI model interactions, supporting: - *
    - *
  • Consistent message formatting - *
  • Parameter substitution - *
  • Context injection - *
  • Response formatting - *
  • Instruction templating - *
- * - *

- * Example prompt registration:

{@code
-	 * new McpServerFeatures.SyncPromptRegistration(
-	 *     new Prompt("analyze", "Code analysis template"),
-	 *     request -> {
-	 *         String code = request.getArguments().get("code");
-	 *         return new GetPromptResult(
-	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
-	 *         );
-	 *     }
-	 * )
-	 * }
- * - * @param prompt The prompt definition including name and description - * @param promptHandler The function that processes prompt requests and returns - * formatted templates - * @deprecated This class is deprecated and will be removed in 0.9.0. Use - * {@link SyncPromptSpecification}. - */ - @Deprecated - public record SyncPromptRegistration(McpSchema.Prompt prompt, - Function promptHandler) { - public SyncPromptSpecification toSpecification() { - return new SyncPromptSpecification(prompt, (exchange, request) -> promptHandler.apply(request)); - } - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 60662d98d..72eba8b86 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -65,40 +65,6 @@ public McpSyncServer(McpAsyncServer asyncServer) { this.asyncServer = asyncServer; } - /** - * Retrieves the list of all roots provided by the client. - * @return The list of roots - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpSyncServerExchange#listRoots()}. - */ - @Deprecated - public McpSchema.ListRootsResult listRoots() { - return this.listRoots(null); - } - - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return The list of roots - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpSyncServerExchange#listRoots(String)}. - */ - @Deprecated - public McpSchema.ListRootsResult listRoots(String cursor) { - return this.asyncServer.listRoots(cursor).block(); - } - - /** - * Add a new tool handler. - * @param toolHandler The tool handler to add - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addTool(McpServerFeatures.SyncToolSpecification)}. - */ - @Deprecated - public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { - this.asyncServer.addTool(McpServerFeatures.AsyncToolRegistration.fromSync(toolHandler)).block(); - } - /** * Add a new tool handler. * @param toolHandler The tool handler to add @@ -115,17 +81,6 @@ public void removeTool(String toolName) { this.asyncServer.removeTool(toolName).block(); } - /** - * Add a new resource handler. - * @param resourceHandler The resource handler to add - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addResource(McpServerFeatures.SyncResourceSpecification)}. - */ - @Deprecated - public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandler) { - this.asyncServer.addResource(McpServerFeatures.AsyncResourceRegistration.fromSync(resourceHandler)).block(); - } - /** * Add a new resource handler. * @param resourceHandler The resource handler to add @@ -142,17 +97,6 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } - /** - * Add a new prompt handler. - * @param promptRegistration The prompt registration to add - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addPrompt(McpServerFeatures.SyncPromptSpecification)}. - */ - @Deprecated - public void addPrompt(McpServerFeatures.SyncPromptRegistration promptRegistration) { - this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptRegistration.fromSync(promptRegistration)).block(); - } - /** * Add a new prompt handler. * @param promptSpecification The prompt specification to add @@ -192,28 +136,6 @@ public McpSchema.Implementation getServerInfo() { return this.asyncServer.getServerInfo(); } - /** - * Get the client capabilities that define the supported features and functionality. - * @return The client capabilities - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpSyncServerExchange#getClientCapabilities()}. - */ - @Deprecated - public ClientCapabilities getClientCapabilities() { - return this.asyncServer.getClientCapabilities(); - } - - /** - * Get the client implementation information. - * @return The client implementation details - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpSyncServerExchange#getClientInfo()}. - */ - @Deprecated - public McpSchema.Implementation getClientInfo() { - return this.asyncServer.getClientInfo(); - } - /** * Notify clients that the list of available resources has changed. */ @@ -258,36 +180,4 @@ public McpAsyncServer getAsyncServer() { return this.asyncServer; } - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling ("completions" or "generations") from language models via clients. - * - *

- * This flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. - * - *

- * Unlike its async counterpart, this method blocks until the message creation is - * complete, making it easier to use in synchronous code paths. - * @param createMessageRequest The request to create a new message - * @return The result of the message creation - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - * @deprecated This method will be removed in 0.9.0. Use - * {@link McpSyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. - */ - @Deprecated - public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - return this.asyncServer.createMessage(createMessageRequest).block(); - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java deleted file mode 100644 index fa5dcf1c1..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java +++ /dev/null @@ -1,419 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.server.transport; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import jakarta.servlet.AsyncContext; -import jakarta.servlet.ServletException; -import jakarta.servlet.annotation.WebServlet; -import jakarta.servlet.http.HttpServlet; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -/** - * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport - * specification. This implementation provides similar functionality to - * WebFluxSseServerTransport but uses the traditional Servlet API instead of WebFlux. - * - * @deprecated This class will be removed in 0.9.0. Use - * {@link HttpServletSseServerTransportProvider}. - * - *

- * The transport handles two types of endpoints: - *

    - *
  • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client - * events
  • - *
  • Message endpoint (configurable) - Handles client-to-server message requests
  • - *
- * - *

- * Features: - *

    - *
  • Asynchronous message handling using Servlet 6.0 async support
  • - *
  • Session management for multiple client connections
  • - *
  • Graceful shutdown support
  • - *
  • Error handling and response formatting
  • - *
- * @author Christian Tzolov - * @author Alexandros Pappas - * @see ServerMcpTransport - * @see HttpServlet - */ - -@WebServlet(asyncSupported = true) -@Deprecated -public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport { - - /** Logger for this class */ - private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransport.class); - - public static final String UTF_8 = "UTF-8"; - - public static final String APPLICATION_JSON = "application/json"; - - public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - - /** Default endpoint path for SSE connections */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** Event type for regular messages */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** Event type for endpoint information */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** JSON object mapper for serialization/deserialization */ - private final ObjectMapper objectMapper; - - /** The endpoint path for handling client messages */ - private final String messageEndpoint; - - /** The endpoint path for handling SSE connections */ - private final String sseEndpoint; - - /** Map of active client sessions, keyed by session ID */ - private final Map sessions = new ConcurrentHashMap<>(); - - /** Flag indicating if the transport is in the process of shutting down */ - private final AtomicBoolean isClosing = new AtomicBoolean(false); - - /** Handler for processing incoming messages */ - private Function, Mono> connectHandler; - - /** - * Creates a new HttpServletSseServerTransport instance with a custom SSE endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - */ - public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this.objectMapper = objectMapper; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - } - - /** - * Creates a new HttpServletSseServerTransport instance with the default SSE endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Handles GET requests to establish SSE connections. - *

- * This method sets up a new SSE connection when a client connects to the SSE - * endpoint. It configures the response headers for SSE, creates a new session, and - * sends the initial endpoint information to the client. - * @param request The HTTP servlet request - * @param response The HTTP servlet response - * @throws ServletException If a servlet-specific error occurs - * @throws IOException If an I/O error occurs - */ - @Override - protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - String pathInfo = request.getPathInfo(); - if (!sseEndpoint.equals(pathInfo)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; - } - - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } - - response.setContentType("text/event-stream"); - response.setCharacterEncoding(UTF_8); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - - String sessionId = UUID.randomUUID().toString(); - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0); - - PrintWriter writer = response.getWriter(); - ClientSession session = new ClientSession(sessionId, asyncContext, writer); - this.sessions.put(sessionId, session); - - // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint); - } - - /** - * Handles POST requests for client messages. - *

- * This method processes incoming messages from clients, routes them through the - * connect handler if configured, and sends back the appropriate response. It handles - * error cases and formats error responses according to the MCP specification. - * @param request The HTTP servlet request - * @param response The HTTP servlet response - * @throws ServletException If a servlet-specific error occurs - * @throws IOException If an I/O error occurs - */ - @Override - protected void doPost(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } - - String pathInfo = request.getPathInfo(); - if (!messageEndpoint.equals(pathInfo)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; - } - - try { - BufferedReader reader = request.getReader(); - StringBuilder body = new StringBuilder(); - String line; - while ((line = reader.readLine()) != null) { - body.append(line); - } - - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); - - if (connectHandler != null) { - connectHandler.apply(Mono.just(message)).subscribe(responseMessage -> { - try { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - String jsonResponse = objectMapper.writeValueAsString(responseMessage); - PrintWriter writer = response.getWriter(); - writer.write(jsonResponse); - writer.flush(); - } - catch (Exception e) { - logger.error("Error sending response: {}", e.getMessage()); - try { - response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - "Error processing response: " + e.getMessage()); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - } - } - }, error -> { - try { - logger.error("Error processing message: {}", error.getMessage()); - McpError mcpError = new McpError(error.getMessage()); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } - catch (IOException e) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage()); - try { - response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - "Error sending error response: " + e.getMessage()); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - } - } - }); - } - else { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "No message handler configured"); - } - } - catch (Exception e) { - logger.error("Invalid message format: {}", e.getMessage()); - try { - McpError mcpError = new McpError("Invalid message format: " + e.getMessage()); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } - catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid message format"); - } - } - } - - /** - * Sets up the message handler for processing client requests. - * @param handler The function to process incoming messages and produce responses - * @return A Mono that completes when the handler is set up - */ - @Override - public Mono connect(Function, Mono> handler) { - this.connectHandler = handler; - return Mono.empty(); - } - - /** - * Broadcasts a message to all connected clients. - *

- * This method serializes the message and sends it to all active client sessions. If a - * client is disconnected, its session is removed. - * @param message The message to broadcast - * @return A Mono that completes when the message has been sent to all clients - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - return Mono.create(sink -> { - try { - String jsonText = objectMapper.writeValueAsString(message); - - sessions.values().forEach(session -> { - try { - this.sendEvent(session.writer, MESSAGE_EVENT_TYPE, jsonText); - } - catch (IOException e) { - logger.error("Failed to send message to session {}: {}", session.id, e.getMessage()); - removeSession(session); - } - }); - - sink.success(); - } - catch (Exception e) { - logger.error("Failed to process message: {}", e.getMessage()); - sink.error(new McpError("Failed to process message: " + e.getMessage())); - } - }); - } - - /** - * Closes the transport. - *

- * This implementation delegates to the super class's close method. - */ - @Override - public void close() { - ServerMcpTransport.super.close(); - } - - /** - * Unmarshals data from one type to another using the object mapper. - * @param The target type - * @param data The source data - * @param typeRef The type reference for the target type - * @return The unmarshaled data - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. - *

- * This method marks the transport as closing and closes all active client sessions. - * New connection attempts will be rejected during shutdown. - * @return A Mono that completes when all sessions have been closed - */ - @Override - public Mono closeGracefully() { - isClosing.set(true); - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - - return Mono.create(sink -> { - sessions.values().forEach(this::removeSession); - sink.success(); - }); - } - - /** - * Sends an SSE event to a client. - * @param writer The writer to send the event through - * @param eventType The type of event (message or endpoint) - * @param data The event data - * @throws IOException If an error occurs while writing the event - */ - private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { - writer.write("event: " + eventType + "\n"); - writer.write("data: " + data + "\n\n"); - writer.flush(); - - if (writer.checkError()) { - throw new IOException("Client disconnected"); - } - } - - /** - * Removes a client session and completes its async context. - * @param session The session to remove - */ - private void removeSession(ClientSession session) { - sessions.remove(session.id); - session.asyncContext.complete(); - } - - /** - * Represents a client connection session. - *

- * This class holds the necessary information about a client's SSE connection, - * including its ID, async context, and output writer. - */ - private static class ClientSession { - - private final String id; - - private final AsyncContext asyncContext; - - private final PrintWriter writer; - - ClientSession(String id, AsyncContext asyncContext, PrintWriter writer) { - this.id = id; - this.asyncContext = asyncContext; - this.writer = writer; - } - - } - - /** - * Cleans up resources when the servlet is being destroyed. - *

- * This method ensures a graceful shutdown by closing all client connections before - * calling the parent's destroy method. - */ - @Override - public void destroy() { - closeGracefully().block(); - super.destroy(); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java deleted file mode 100644 index 78264ca32..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.Executors; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; - -/** - * Implementation of the MCP Stdio transport for servers that communicates using standard - * input/output streams. Messages are exchanged as newline-delimited JSON-RPC messages - * over stdin/stdout, with errors and debug information sent to stderr. - * - * @author Christian Tzolov - * @deprecated This method will be removed in 0.9.0. Use - * {@link io.modelcontextprotocol.server.transport.StdioServerTransportProvider} instead. - */ -@Deprecated -public class StdioServerTransport implements ServerMcpTransport { - - private static final Logger logger = LoggerFactory.getLogger(StdioServerTransport.class); - - private final Sinks.Many inboundSink; - - private final Sinks.Many outboundSink; - - private ObjectMapper objectMapper; - - /** Scheduler for handling inbound messages */ - private Scheduler inboundScheduler; - - /** Scheduler for handling outbound messages */ - private Scheduler outboundScheduler; - - private volatile boolean isClosing = false; - - private final InputStream inputStream; - - private final OutputStream outputStream; - - private final Sinks.One inboundReady = Sinks.one(); - - private final Sinks.One outboundReady = Sinks.one(); - - /** - * Creates a new StdioServerTransport with a default ObjectMapper and System streams. - */ - public StdioServerTransport() { - this(new ObjectMapper()); - } - - /** - * Creates a new StdioServerTransport with the specified ObjectMapper and System - * streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioServerTransport(ObjectMapper objectMapper) { - - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); - - this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); - this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); - - this.objectMapper = objectMapper; - this.inputStream = System.in; - this.outputStream = System.out; - - // Use bounded schedulers for better resource management - this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "inbound"); - this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); - } - - @Override - public Mono connect(Function, Mono> handler) { - return Mono.fromRunnable(() -> { - handleIncomingMessages(handler); - - // Start threads - startInboundProcessing(); - startOutboundProcessing(); - }).subscribeOn(Schedulers.boundedElastic()); - } - - private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { - this.inboundSink.asFlux() - .flatMap(message -> Mono.just(message) - .transform(inboundMessageHandler) - .contextWrite(ctx -> ctx.put("observation", "myObservation"))) - .doOnTerminate(() -> { - // The outbound processing will dispose its scheduler upon completion - this.outboundSink.tryEmitComplete(); - this.inboundScheduler.dispose(); - }) - .subscribe(); - } - - @Override - public Mono sendMessage(JSONRPCMessage message) { - return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { - if (this.outboundSink.tryEmitNext(message).isSuccess()) { - return Mono.empty(); - } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); - } - })); - } - - /** - * Starts the inbound processing thread that reads JSON-RPC messages from stdin. - * Messages are deserialized and emitted to the inbound sink. - */ - private void startInboundProcessing() { - this.inboundScheduler.schedule(() -> { - inboundReady.tryEmitValue(null); - BufferedReader reader = null; - try { - reader = new BufferedReader(new InputStreamReader(inputStream)); - while (!isClosing) { - try { - String line = reader.readLine(); - if (line == null || isClosing) { - break; - } - - logger.debug("Received JSON message: {}", line); - - try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); - if (!this.inboundSink.tryEmitNext(message).isSuccess()) { - logIfNotClosing("Failed to enqueue message"); - break; - } - } - catch (Exception e) { - logIfNotClosing("Error processing inbound message", e); - break; - } - } - catch (IOException e) { - logIfNotClosing("Error reading from stdin", e); - break; - } - } - } - catch (Exception e) { - logIfNotClosing("Error in inbound processing", e); - } - finally { - isClosing = true; - inboundSink.tryEmitComplete(); - } - }); - } - - /** - * Starts the outbound processing thread that writes JSON-RPC messages to stdout. - * Messages are serialized to JSON and written with a newline delimiter. - */ - private void startOutboundProcessing() { - Function, Flux> outboundConsumer = messages -> messages // @formatter:off - .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) - .publishOn(outboundScheduler) - .handle((message, sink) -> { - if (message != null && !isClosing) { - try { - String jsonMessage = objectMapper.writeValueAsString(message); - // Escape any embedded newlines in the JSON message as per spec - jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); - - synchronized (outputStream) { - outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); - outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); - outputStream.flush(); - } - sink.next(message); - } - catch (IOException e) { - if (!isClosing) { - logger.error("Error writing message", e); - sink.error(new RuntimeException(e)); - } - else { - logger.debug("Stream closed during shutdown", e); - } - } - } - else if (isClosing) { - sink.complete(); - } - }) - .doOnComplete(() -> { - isClosing = true; - outboundScheduler.dispose(); - }) - .doOnError(e -> { - if (!isClosing) { - logger.error("Error in outbound processing", e); - isClosing = true; - outboundScheduler.dispose(); - } - }) - .map(msg -> (JSONRPCMessage) msg); - - outboundConsumer.apply(outboundSink.asFlux()).subscribe(); - } // @formatter:on - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown"); - // Completing the inbound causes the outbound to be completed as well, so - // we only close the inbound. - inboundSink.tryEmitComplete(); - logger.debug("Graceful shutdown complete"); - return Mono.empty(); - }).subscribeOn(Schedulers.boundedElastic()); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - private void logIfNotClosing(String message, Exception e) { - if (!this.isClosing) { - logger.error(message, e); - } - } - - private void logIfNotClosing(String message) { - if (!this.isClosing) { - logger.error(message); - } - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java deleted file mode 100644 index 8464b6ae7..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java +++ /dev/null @@ -1,15 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -/** - * Marker interface for the client-side MCP transport. - * - * @author Christian Tzolov - * @deprecated This class will be removed in 0.9.0. Use {@link McpClientTransport}. - */ -@Deprecated -public interface ClientMcpTransport extends McpTransport { - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java deleted file mode 100644 index 83de4c094..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoSink; - -/** - * Default implementation of the MCP (Model Context Protocol) session that manages - * bidirectional JSON-RPC communication between clients and servers. This implementation - * follows the MCP specification for message exchange and transport handling. - * - *

- * The session manages: - *

    - *
  • Request/response handling with unique message IDs
  • - *
  • Notification processing
  • - *
  • Message timeout management
  • - *
  • Transport layer abstraction
  • - *
- * - * @author Christian Tzolov - * @author Dariusz Jędrzejczyk - * @deprecated This method will be removed in 0.9.0. Use {@link McpClientSession} instead - */ -@Deprecated - -public class DefaultMcpSession implements McpSession { - - /** Logger for this class */ - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSession.class); - - /** Duration to wait for request responses before timing out */ - private final Duration requestTimeout; - - /** Transport layer implementation for message exchange */ - private final McpTransport transport; - - /** Map of pending responses keyed by request ID */ - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); - - /** Map of request handlers keyed by method name */ - private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); - - /** Map of notification handlers keyed by method name */ - private final ConcurrentHashMap notificationHandlers = new ConcurrentHashMap<>(); - - /** Session-specific prefix for request IDs */ - private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); - - /** Atomic counter for generating unique request IDs */ - private final AtomicLong requestCounter = new AtomicLong(0); - - private final Disposable connection; - - /** - * Functional interface for handling incoming JSON-RPC requests. Implementations - * should process the request parameters and return a response. - * - * @param Response type - */ - @FunctionalInterface - public interface RequestHandler { - - /** - * Handles an incoming request with the given parameters. - * @param params The request parameters - * @return A Mono containing the response object - */ - Mono handle(Object params); - - } - - /** - * Functional interface for handling incoming JSON-RPC notifications. Implementations - * should process the notification parameters without returning a response. - */ - @FunctionalInterface - public interface NotificationHandler { - - /** - * Handles an incoming notification with the given parameters. - * @param params The notification parameters - * @return A Mono that completes when the notification is processed - */ - Mono handle(Object params); - - } - - /** - * Creates a new DefaultMcpSession with the specified configuration and handlers. - * @param requestTimeout Duration to wait for responses - * @param transport Transport implementation for message exchange - * @param requestHandlers Map of method names to request handlers - * @param notificationHandlers Map of method names to notification handlers - */ - public DefaultMcpSession(Duration requestTimeout, McpTransport transport, - Map> requestHandlers, Map notificationHandlers) { - - Assert.notNull(requestTimeout, "The requstTimeout can not be null"); - Assert.notNull(transport, "The transport can not be null"); - Assert.notNull(requestHandlers, "The requestHandlers can not be null"); - Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); - - this.requestTimeout = requestTimeout; - this.transport = transport; - this.requestHandlers.putAll(requestHandlers); - this.notificationHandlers.putAll(notificationHandlers); - - // TODO: consider mono.transformDeferredContextual where the Context contains - // the - // Observation associated with the individual message - it can be used to - // create child Observation and emit it together with the message to the - // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unkown id {}", response.id()); - } - else { - sink.success(response); - } - } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), - error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - transport.sendMessage(errorResponse).subscribe(); - }); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification).subscribe(null, - error -> logger.error("Error handling notification: {}", error.getMessage())); - } - })).subscribe(); - } - - /** - * Handles an incoming JSON-RPC request by routing it to the appropriate handler. - * @param request The incoming JSON-RPC request - * @return A Mono containing the JSON-RPC response - */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { - return Mono.defer(() -> { - var handler = this.requestHandlers.get(request.method()); - if (handler == null) { - MethodNotFoundError error = getMethodNotFoundError(request.method()); - return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, - error.message(), error.data()))); - } - - return handler.handle(request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field - }); - } - - record MethodNotFoundError(String method, String message, Object data) { - } - - public static MethodNotFoundError getMethodNotFoundError(String method) { - switch (method) { - case McpSchema.METHOD_ROOTS_LIST: - return new MethodNotFoundError(method, "Roots not supported", - Map.of("reason", "Client does not have roots capability")); - default: - return new MethodNotFoundError(method, "Method not found: " + method, null); - } - } - - /** - * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. - * @param notification The incoming JSON-RPC notification - * @return A Mono that completes when the notification is processed - */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { - return Mono.defer(() -> { - var handler = notificationHandlers.get(notification.method()); - if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); - return Mono.empty(); - } - return handler.handle(notification.params()); - }); - } - - /** - * Generates a unique request ID in a non-blocking way. Combines a session-specific - * prefix with an atomic counter to ensure uniqueness. - * @return A unique request ID string - */ - private String generateRequestId() { - return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); - } - - /** - * Sends a JSON-RPC request and returns the response. - * @param The expected response type - * @param method The method name to call - * @param requestParams The request parameters - * @param typeRef Type reference for response deserialization - * @return A Mono containing the response - */ - @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); - - return Mono.create(sink -> { - this.pendingResponses.put(requestId, sink); - McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, - requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest) - // TODO: It's most efficient to create a dedicated Subscriber here - .subscribe(v -> { - }, error -> { - this.pendingResponses.remove(requestId); - sink.error(error); - }); - }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { - if (jsonRpcResponse.error() != null) { - sink.error(new McpError(jsonRpcResponse.error())); - } - else { - if (typeRef.getType().equals(Void.class)) { - sink.complete(); - } - else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); - } - } - }); - } - - /** - * Sends a JSON-RPC notification. - * @param method The method name for the notification - * @param params The notification parameters - * @return A Mono that completes when the notification is sent - */ - @Override - public Mono sendNotification(String method, Map params) { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - method, params); - return this.transport.sendMessage(jsonrpcNotification); - } - - /** - * Closes the session gracefully, allowing pending operations to complete. - * @return A Mono that completes when the session is closed - */ - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - this.connection.dispose(); - return transport.closeGracefully(); - }); - } - - /** - * Closes the session immediately, potentially interrupting pending operations. - */ - @Override - public void close() { - this.connection.dispose(); - transport.close(); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 6657e3622..e29646e6a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -44,7 +44,7 @@ public class McpClientSession implements McpSession { private final Duration requestTimeout; /** Transport layer implementation for message exchange */ - private final McpTransport transport; + private final McpClientTransport transport; /** Map of pending responses keyed by request ID */ private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); @@ -104,7 +104,7 @@ public interface NotificationHandler { * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers */ - public McpClientSession(Duration requestTimeout, McpTransport transport, + public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { Assert.notNull(requestTimeout, "The requstTimeout can not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 458979651..f29091248 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -13,9 +13,8 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public interface McpClientTransport extends ClientMcpTransport { +public interface McpClientTransport extends McpTransport { - @Override Mono connect(Function, Mono> handler); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index f698d8789..40d9ba7ac 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.spec; -import java.util.function.Function; - import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import reactor.core.publisher.Mono; @@ -39,21 +37,6 @@ */ public interface McpTransport { - /** - * Initializes and starts the transport connection. - * - *

- * This method should be called before any message exchange can occur. It sets up the - * necessary resources and establishes the connection to the server. - *

- * @deprecated This is only relevant for client-side transports and will be removed - * from this interface in 0.9.0. - */ - @Deprecated - default Mono connect(Function, Mono> handler) { - return Mono.empty(); - } - /** * Closes the transport connection and releases any associated resources. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java deleted file mode 100644 index 704daee0f..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ /dev/null @@ -1,15 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.spec; - -/** - * Marker interface for the server-side MCP transport. - * - * @author Christian Tzolov - * @deprecated This class will be removed in 0.9.0. Use {@link McpServerTransport}. - */ -@Deprecated -public interface ServerMcpTransport extends McpTransport { - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java rename to mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index 12f30d12f..482d0aac6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -13,30 +13,28 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} - * interfaces. + * A mock implementation of the {@link McpClientTransport} interfaces. */ -public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { +public class MockMcpClientTransport implements McpClientTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); private final List sent = new ArrayList<>(); - private final BiConsumer interceptor; + private final BiConsumer interceptor; - public MockMcpTransport() { + public MockMcpClientTransport() { this((t, msg) -> { }); } - public MockMcpTransport(BiConsumer interceptor) { + public MockMcpClientTransport(BiConsumer interceptor) { this.interceptor = interceptor; } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java new file mode 100644 index 000000000..4be680e11 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpServerTransport; +import reactor.core.publisher.Mono; + +/** + * A mock implementation of the {@link McpServerTransport} interfaces. + */ +public class MockMcpServerTransport implements McpServerTransport { + + private final List sent = new ArrayList<>(); + + private final BiConsumer interceptor; + + public MockMcpServerTransport() { + this((t, msg) -> { + }); + } + + public MockMcpServerTransport(BiConsumer interceptor) { + this.interceptor = interceptor; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + sent.add(message); + interceptor.accept(this, message); + return Mono.empty(); + } + + public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() { + return (JSONRPCRequest) getLastSentMessage(); + } + + public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() { + return (JSONRPCNotification) getLastSentMessage(); + } + + public McpSchema.JSONRPCMessage getLastSentMessage() { + return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return new ObjectMapper().convertValue(data, typeRef); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java new file mode 100644 index 000000000..3fb19180b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -0,0 +1,63 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package io.modelcontextprotocol; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerSession.Factory; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import reactor.core.publisher.Mono; + +/** + * @author Christian Tzolov + */ +public class MockMcpServerTransportProvider implements McpServerTransportProvider { + + private McpServerSession session; + + private final MockMcpServerTransport transport; + + public MockMcpServerTransportProvider(MockMcpServerTransport transport) { + this.transport = transport; + } + + public MockMcpServerTransport getTransport() { + return transport; + } + + @Override + public void setSessionFactory(Factory sessionFactory) { + + session = sessionFactory.create(transport); + } + + @Override + public Mono notifyClients(String method, Map params) { + return session.sendNotification(method, params); + } + + @Override + public Mono closeGracefully() { + return session.closeGracefully(); + } + + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { + session.handle(message).subscribe(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index b1e82b748..4510b1529 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -34,16 +34,16 @@ class McpAsyncClientResponseHandlerTests { .resources(true, true) // Enable both resources and resource templates .build(); - private static MockMcpTransport initializationEnabledTransport() { + private static MockMcpClientTransport initializationEnabledTransport() { return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO); } - private static MockMcpTransport initializationEnabledTransport(McpSchema.ServerCapabilities mockServerCapabilities, - McpSchema.Implementation mockServerInfo) { + private static MockMcpClientTransport initializationEnabledTransport( + McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) { McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, mockServerCapabilities, mockServerInfo, "Test instructions"); - return new MockMcpTransport((t, message) -> { + return new MockMcpClientTransport((t, message) -> { if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) { McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, r.id(), mockInitResult, null); @@ -59,7 +59,7 @@ void testSuccessfulInitialization() { .tools(false) .resources(true, true) // Enable both resources and resource templates .build(); - MockMcpTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); + MockMcpClientTransport transport = initializationEnabledTransport(serverCapabilities, serverInfo); McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); // Verify client is not initialized initially @@ -91,7 +91,7 @@ void testSuccessfulInitialization() { @Test void testToolsChangeNotificationHandling() throws JsonProcessingException { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification List receivedTools = new ArrayList<>(); @@ -134,7 +134,7 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { @Test void testRootsListRequestHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); McpAsyncClient asyncMcpClient = McpClient.async(transport) .roots(new Root("file:///test/path", "test-root")) @@ -162,7 +162,7 @@ void testRootsListRequestHandling() { @Test void testResourcesChangeNotificationHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received resources for verification List receivedResources = new ArrayList<>(); @@ -208,7 +208,7 @@ void testResourcesChangeNotificationHandling() { @Test void testPromptsChangeNotificationHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received prompts for verification List receivedPrompts = new ArrayList<>(); @@ -252,7 +252,7 @@ void testPromptsChangeNotificationHandling() { @Test void testSamplingCreateMessageRequestHandling() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create a test sampling handler that echoes back the input Function> samplingHandler = request -> { @@ -306,7 +306,7 @@ void testSamplingCreateMessageRequestHandling() { @Test void testSamplingCreateMessageRequestHandlingWithoutCapability() { - MockMcpTransport transport = initializationEnabledTransport(); + MockMcpClientTransport transport = initializationEnabledTransport(); // Create client without sampling capability McpAsyncClient asyncMcpClient = McpClient.async(transport) @@ -340,7 +340,7 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { @Test void testSamplingCreateMessageRequestHandlingWithNullHandler() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); // Create client with sampling capability but null handler assertThatThrownBy( diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 58e486e19..bf4738496 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; @@ -28,7 +28,7 @@ class McpClientProtocolVersionTests { @Test void shouldUseLatestVersionByDefault() { - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -61,7 +61,7 @@ void shouldUseLatestVersionByDefault() { @Test void shouldNegotiateSpecificVersion() { String oldVersion = "0.1.0"; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -94,7 +94,7 @@ void shouldNegotiateSpecificVersion() { @Test void shouldFailForUnsupportedVersion() { String unsupportedVersion = "999.999.999"; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) @@ -124,7 +124,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { String middleVersion = "0.2.0"; String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; - MockMcpTransport transport = new MockMcpTransport(); + MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) .clientInfo(CLIENT_INFO) .requestTimeout(REQUEST_TIMEOUT) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java deleted file mode 100644 index b9a19de6c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java +++ /dev/null @@ -1,466 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.time.Duration; -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -@Deprecated -public abstract class AbstractMcpAsyncServerDeprecatedTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); - } - - @Test - void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - @Test - void testRemoveResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); - }); - } - - @Test - void testAddPromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .build(); - - StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePrompt() { - String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); - - assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - }))) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java deleted file mode 100644 index 16bc2d6e4..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java +++ /dev/null @@ -1,433 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.List; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -@Deprecated -public abstract class AbstractMcpSyncServerDeprecatedTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected ServerMcpTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - // onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); - } - - @Test - void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) - .build(); - - assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Resource must not be null"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); - - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - @Test - void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); - } - - @Test - void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) - .build(); - - assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeConsumers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - })) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index 97358723f..f643f1ba3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -7,7 +7,8 @@ import java.util.List; import java.util.UUID; -import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; @@ -29,14 +30,16 @@ private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(String requestId, Stri @Test void shouldUseLatestVersionByDefault() { - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); + transportProvider + .simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); @@ -50,16 +53,18 @@ void shouldUseLatestVersionByDefault() { @Test void shouldNegotiateSpecificVersion() { String oldVersion = "0.1.0"; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); server.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); @@ -73,14 +78,16 @@ void shouldNegotiateSpecificVersion() { @Test void shouldSuggestLatestVersionForUnsupportedVersion() { String unsupportedVersion = "999.999.999"; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); @@ -97,15 +104,17 @@ void shouldUseHighestVersionWhenMultipleSupported() { String middleVersion = "0.2.0"; String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; - MockMcpTransport transport = new MockMcpTransport(); - McpAsyncServer server = McpServer.async(transport).serverInfo(SERVER_INFO).build(); + MockMcpServerTransport serverTransport = new MockMcpServerTransport(); + var transportProvider = new MockMcpServerTransportProvider(serverTransport); + + McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); server.setProtocolVersions(List.of(oldVersion, middleVersion, latestVersion)); String requestId = UUID.randomUUID().toString(); - transport.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); - McpSchema.JSONRPCMessage response = transport.getLastSentMessage(); + McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse jsonResponse = (McpSchema.JSONRPCResponse) response; assertThat(jsonResponse.id()).isEqualTo(requestId); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java deleted file mode 100644 index 2c80d45c6..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java deleted file mode 100644 index 8cdd08c5d..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java deleted file mode 100644 index db95db07b..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 27ff53c93..0381a43bd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.transport.StdioServerTransport; import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java deleted file mode 100644 index 149f72819..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import org.junit.jupiter.api.Timeout; - -/** - * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -@Deprecated -@Timeout(15) // Giving extra time beyond the client timeout -class StdioMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { - - @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java deleted file mode 100644 index 4a292da31..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportIntegrationTests.java +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server.transport; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.test.StepVerifier; - -import org.springframework.web.client.RestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.awaitility.Awaitility.await; - -public class HttpServletSseServerTransportIntegrationTests { - - private static final int PORT = 8184; - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private HttpServletSseServerTransport mcpServerTransport; - - McpClient.SyncSpec clientBuilder; - - private Tomcat tomcat; - - @BeforeEach - public void before() { - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - Context context = tomcat.addContext("", baseDir); - - // Create and configure the transport - mcpServerTransport = new HttpServletSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - - // Add transport servlet to Tomcat - org.apache.catalina.Wrapper wrapper = context.createWrapper(); - wrapper.setName("mcpServlet"); - wrapper.setServlet(mcpServerTransport); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addChild(wrapper); - context.addServletMappingDecoded("/*", "mcpServlet"); - - try { - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); - } - - @AfterEach - public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @Test - void testCreateMessageWithoutSamplingCapabilities() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @Test - void testCreateMessageSuccess() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } - - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) - .build(); - - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - mcpClient.close(); - mcpServer.close(); - } - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - AtomicReference> toolsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - String response = RestClient.create() - .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - toolsRef.set(toolsUpdate); - }).build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(toolsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); - }); - - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).isEmpty(); - }); - - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransport).build(); - var mcpClient = clientBuilder.build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.close(); - mcpServer.close(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java deleted file mode 100644 index 43e5019fc..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportTests.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.io.PrintStream; -import java.nio.charset.StandardCharsets; -import java.util.Map; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for {@link StdioServerTransport}. - * - * @author Christian Tzolov - */ -class StdioServerTransportTests { - - private final InputStream originalIn = System.in; - - private final PrintStream originalOut = System.out; - - private final PrintStream originalErr = System.err; - - private ByteArrayOutputStream testOut; - - private ByteArrayOutputStream testErr; - - private PrintStream testOutPrintStream; - - private StdioServerTransport transport; - - private ObjectMapper objectMapper; - - @BeforeEach - void setUp() { - testOut = new ByteArrayOutputStream(); - testErr = new ByteArrayOutputStream(); - testOutPrintStream = new PrintStream(testOut, true); - System.setOut(testOutPrintStream); - System.setErr(new PrintStream(testErr)); - - objectMapper = new ObjectMapper(); - } - - @AfterEach - void tearDown() { - if (transport != null) { - transport.closeGracefully().block(); - } - if (testOutPrintStream != null) { - testOutPrintStream.close(); - } - System.setIn(originalIn); - System.setOut(originalOut); - System.setErr(originalErr); - } - - @Test - void shouldHandleIncomingMessages() throws Exception { - // Prepare test input - String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}"; - - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Parse expected message - McpSchema.JSONRPCRequest expected = objectMapper.readValue(jsonMessage, McpSchema.JSONRPCRequest.class); - - // Connect transport with message handler and verify message - StepVerifier.create(transport.connect(message -> message.doOnNext(msg -> { - McpSchema.JSONRPCRequest received = (McpSchema.JSONRPCRequest) msg; - assertThat(received.id()).isEqualTo(expected.id()); - assertThat(received.method()).isEqualTo(expected.method()); - }))).verifyComplete(); - } - - @Test - @Disabled - void shouldHandleOutgoingMessages() throws Exception { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - // transport = new StdioServerTransport(objectMapper, new BlockingInputStream(), - // testOutPrintStream); - - // Create test messages - JSONRPCRequest initMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "init", "init-id", - Map.of("init", "true")); - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test", "test-id", - Map.of("key", "value")); - - // Connect transport, send messages, and verify output in a reactive chain - StepVerifier.create(transport.connect(message -> message) - .then(transport.sendMessage(initMessage)) - // .then(Mono.fromRunnable(() -> testOut.reset())) // Clear buffer after init - // message - .then(transport.sendMessage(testMessage)) - .then(Mono.fromCallable(() -> { - String output = testOut.toString(StandardCharsets.UTF_8); - assertThat(output).contains("\"jsonrpc\":\"2.0\""); - assertThat(output).contains("\"method\":\"test\""); - assertThat(output).contains("\"id\":\"test-id\""); - return null; - }))).verifyComplete(); - } - - @Test - void shouldWaitForProcessorsBeforeSendingMessage() { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Create test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test", "test-id", - Map.of("key", "value")); - - // Try to send message before connecting (before processors are ready) - StepVerifier.create(transport.sendMessage(testMessage)).verifyTimeout(java.time.Duration.ofMillis(100)); - - // Connect transport and verify message can be sent - StepVerifier.create(transport.connect(message -> message).then(transport.sendMessage(testMessage))) - .verifyComplete(); - } - - @Test - void shouldCloseGracefully() { - // Create transport with test streams - transport = new StdioServerTransport(objectMapper); - - // Create test message - JSONRPCRequest initMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "init", "init-id", - Map.of("init", "true")); - - // Connect transport, send message, and close gracefully in a reactive chain - StepVerifier - .create(transport.connect(message -> message) - .then(transport.sendMessage(initMessage)) - .then(transport.closeGracefully())) - .verifyComplete(); - - // Verify error log is empty - assertThat(testErr.toString()).doesNotContain("Error"); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 79a1d0d92..715d6651e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -8,7 +8,7 @@ import java.util.Map; import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.MockMcpTransport; +import io.modelcontextprotocol.MockMcpClientTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,11 +41,11 @@ class McpClientSessionTests { private McpClientSession session; - private MockMcpTransport transport; + private MockMcpClientTransport transport; @BeforeEach void setUp() { - transport = new MockMcpTransport(); + transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); } @@ -139,7 +139,7 @@ void testRequestHandling() { String echoMessage = "Hello MCP!"; Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); - transport = new MockMcpTransport(); + transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); // Simulate incoming request @@ -159,7 +159,7 @@ void testRequestHandling() { void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); - transport = new MockMcpTransport(); + transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); From 2934635d99cd83bcafe26051e1a231bb6681cfbd Mon Sep 17 00:00:00 2001 From: or-givati Date: Mon, 24 Mar 2025 14:41:07 +0200 Subject: [PATCH 003/205] feat(mcp): customize transport endpoints and improve URI handling (#69) - Add support for customizable SSE endpoints in HttpClientSseClientTransport - Replace pathInfo with requestURI in HttpServletSseServerTransportProvider for more reliable endpoint matching - Implement builder pattern to support the customization options Related to #40 Signed-off-by: Christian Tzolov Co-authored-by: Christian Tzolov --- .../HttpClientSseClientTransport.java | 96 ++++++++++++++++++- ...HttpServletSseServerTransportProvider.java | 86 ++++++++++++++++- .../server/ServletSseMcpAsyncServerTests.java | 3 +- .../server/ServletSseMcpSyncServerTests.java | 3 +- ...rverTransportProviderIntegrationTests.java | 14 ++- 5 files changed, 189 insertions(+), 13 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index ca1b0e87a..696efdffd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -65,11 +65,14 @@ public class HttpClientSseClientTransport implements McpClientTransport { private static final String ENDPOINT_EVENT_TYPE = "endpoint"; /** Default SSE endpoint path */ - private static final String SSE_ENDPOINT = "/sse"; + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** Base URI for the MCP server */ private final String baseUri; + /** SSE endpoint path */ + private final String sseEndpoint; + /** SSE client for handling server-sent events. Uses the /sse endpoint */ private final FlowSseClient sseClient; @@ -110,15 +113,104 @@ public HttpClientSseClientTransport(String baseUri) { * @throws IllegalArgumentException if objectMapper or clientBuilder is null */ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { + this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder and object mapper. + * @param clientBuilder the HTTP client builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper or clientBuilder is null + */ + public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, + ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(clientBuilder, "clientBuilder must not be null"); this.baseUri = baseUri; + this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); this.sseClient = new FlowSseClient(this.httpClient); } + /** + * Creates a new builder for {@link HttpClientSseClientTransport}. + * @param baseUri the base URI of the MCP server + * @return a new builder instance + */ + public static Builder builder(String baseUri) { + return new Builder(baseUri); + } + + /** + * Builder for {@link HttpClientSseClientTransport}. + */ + public static class Builder { + + private final String baseUri; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder(); + + private ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + */ + public Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Sets the object mapper for JSON serialization/deserialization. + * @param objectMapper the object mapper + * @return this builder + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a new {@link HttpClientSseClientTransport} instance. + * @return a new transport instance + */ + public HttpClientSseClientTransport build() { + return new HttpClientSseClientTransport(clientBuilder, baseUri, sseEndpoint, objectMapper); + } + + } + /** * Establishes the SSE connection with the server and sets up message handling. * @@ -137,7 +229,7 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - sseClient.subscribe(this.baseUri + SSE_ENDPOINT, new FlowSseClient.SseEventHandler() { + sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 152462b1d..a64b4a353 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -170,8 +171,8 @@ public Mono notifyClients(String method, Map params) { protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - String pathInfo = request.getPathInfo(); - if (!sseEndpoint.equals(pathInfo)) { + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(sseEndpoint)) { response.sendError(HttpServletResponse.SC_NOT_FOUND); return; } @@ -225,8 +226,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - String pathInfo = request.getPathInfo(); - if (!messageEndpoint.equals(pathInfo)) { + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(messageEndpoint)) { response.sendError(HttpServletResponse.SC_NOT_FOUND); return; } @@ -429,4 +430,81 @@ public void close() { } + /** + * Creates a new Builder instance for configuring and creating instances of + * HttpServletSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of HttpServletSseServerTransportProvider. + *

+ * This builder provides a fluent API for configuring and creating instances of + * HttpServletSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param objectMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

+ * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Builds a new instance of HttpServletSseServerTransportProvider with the + * configured settings. + * @return A new HttpServletSseServerTransportProvider instance + * @throws IllegalStateException if objectMapper or messageEndpoint is not set + */ + public HttpServletSseServerTransportProvider build() { + if (objectMapper == null) { + throw new IllegalStateException("ObjectMapper must be set"); + } + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new HttpServletSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + } + + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 9de186b4b..81d904292 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -19,7 +18,7 @@ class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override protected McpServerTransportProvider createMcpTransportProvider() { - return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 60dc53a4a..154cf3a61 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; @@ -19,7 +18,7 @@ class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override protected McpServerTransportProvider createMcpTransportProvider() { - return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index fd8a4e9f9..1cd395e74 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -47,7 +47,9 @@ public class HttpServletSseServerTransportProviderIntegrationTests { private static final int PORT = 8185; - private static final String MESSAGE_ENDPOINT = "/mcp/message"; + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; private HttpServletSseServerTransportProvider mcpServerTransportProvider; @@ -66,7 +68,11 @@ public void before() { Context context = tomcat.addContext("", baseDir); // Create and configure the transport provider - mcpServerTransportProvider = new HttpServletSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); // Add transport servlet to Tomcat org.apache.catalina.Wrapper wrapper = context.createWrapper(); @@ -87,7 +93,9 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()); } @AfterEach From 55ee15604a5b2408992f37189f7a83843f5e759f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 27 Mar 2025 14:04:33 +0100 Subject: [PATCH 004/205] feat(webflux): Add configurable SSE endpoints to WebFlux transport (#41, #67) Enhances WebFlux SSE transport implementation with customizable endpoint paths: - Add configurable SSE endpoint support in both client and server transports - Update tests to verify custom SSE endpoint functionality - Implement builder pattern to support the new configuration options Co-authored-by: haidao Co-authored-by: Harry <34418180+HarryFQG@users.noreply.github.com> Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseClientTransport.java | 88 ++++++++++++++++++- .../WebFluxSseServerTransportProvider.java | 70 +++++++++++++++ .../WebFluxSseIntegrationTests.java | 23 +++-- .../client/WebFluxSseMcpAsyncClientTests.java | 2 +- .../client/WebFluxSseMcpSyncClientTests.java | 2 +- .../WebFluxSseClientTransportTests.java | 36 ++++++-- .../server/WebFluxSseMcpAsyncServerTests.java | 4 +- .../server/WebFluxSseMcpSyncServerTests.java | 4 +- 8 files changed, 210 insertions(+), 19 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index b0dfa89c0..37abe295b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -79,7 +79,7 @@ public class WebFluxSseClientTransport implements McpClientTransport { * Default SSE endpoint path as specified by the MCP transport specification. This * endpoint is used to establish the SSE connection with the server. */ - private static final String SSE_ENDPOINT = "/sse"; + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** * Type reference for parsing SSE events containing string data. @@ -117,6 +117,12 @@ public class WebFluxSseClientTransport implements McpClientTransport { */ protected final Sinks.One messageEndpointSink = Sinks.one(); + /** + * The SSE endpoint URI provided by the server. Used for sending outbound messages via + * HTTP POST requests. + */ + private String sseEndpoint; + /** * Constructs a new SseClientTransport with the specified WebClient builder. Uses a * default ObjectMapper instance for JSON processing. @@ -137,11 +143,27 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { + this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); + } + + /** + * Constructs a new SseClientTransport with the specified WebClient builder and + * ObjectMapper. Initializes both inbound and outbound message processing pipelines. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @param objectMapper the ObjectMapper to use for JSON processing + * @param sseEndpoint the SSE endpoint URI to use for establishing the connection + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, + String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); this.objectMapper = objectMapper; this.webClient = webClientBuilder.build(); + this.sseEndpoint = sseEndpoint; } /** @@ -254,7 +276,7 @@ public Mono sendMessage(JSONRPCMessage message) { protected Flux> eventStream() {// @formatter:off return this.webClient .get() - .uri(SSE_ENDPOINT) + .uri(this.sseEndpoint) .accept(MediaType.TEXT_EVENT_STREAM) .retrieve() .bodyToFlux(SSE_TYPE) @@ -321,4 +343,66 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); } + /** + * Creates a new builder for {@link WebFluxSseClientTransport}. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @return a new builder instance + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + /** + * Builder for {@link WebFluxSseClientTransport}. + */ + public static class Builder { + + private final WebClient.Builder webClientBuilder; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Creates a new builder with the specified WebClient.Builder. + * @param webClientBuilder the WebClient.Builder to use + */ + public Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the object mapper for JSON serialization/deserialization. + * @param objectMapper the object mapper + * @return this builder + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a new {@link WebFluxSseClientTransport} instance. + * @return a new transport instance + */ + public WebFluxSseClientTransport build() { + return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint); + } + + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 4e5d2fafb..85a39a82f 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -346,4 +346,74 @@ public void close() { } + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxSseServerTransportProvider}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if sseEndpoint is null + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxSseServerTransportProvider build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(messageEndpoint, "Message endpoint must be set"); + + return new WebFluxSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + } + + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2d9d055f3..2be2f81f2 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -46,14 +46,17 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; -import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; public class WebFluxSseIntegrationTests { private static final int PORT = 8182; - private static final String MESSAGE_ENDPOINT = "/mcp/message"; + // private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; private DisposableServer httpServer; @@ -64,15 +67,25 @@ public class WebFluxSseIntegrationTests { @BeforeEach public void before() { - this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); + clientBulders.put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build())); clientBulders.put("webflux", - McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); + McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build())); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 2dd587d4f..b43c14493 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -33,7 +33,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); + return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 72b390ddd..66ac8a6dd 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -33,7 +33,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); + return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 912e04f14..c757d3da9 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -63,13 +63,6 @@ public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper o super(webClientBuilder, objectMapper); } - // @Override - // public Mono connect(Function, - // Mono> handler) { - // simulateEndpointEvent("https://localhost:3001"); - // return super.connect(handler); - // } - @Override protected Flux> eventStream() { return super.eventStream().mergeWith(events.asFlux()); @@ -137,6 +130,33 @@ void constructorValidation() { .hasMessageContaining("ObjectMapper must not be null"); } + @Test + void testBuilderPattern() { + // Test default builder + WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build(); + assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom ObjectMapper + ObjectMapper customMapper = new ObjectMapper(); + WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) + .objectMapper(customMapper) + .build(); + assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with custom SSE endpoint + WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException(); + + // Test builder with all custom parameters + WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) + .objectMapper(customMapper) + .sseEndpoint("/custom-sse") + .build(); + assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); + } + @Test void testMessageProcessing() { // Create a test message @@ -240,7 +260,7 @@ void testRetryBehavior() { // Create a WebClient that simulates connection failures WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host"); - WebFluxSseClientTransport failingTransport = new WebFluxSseClientTransport(failingWebClientBuilder); + WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 5fa787ab6..98844c741 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -31,7 +31,9 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override protected McpServerTransportProvider createMcpTransportProvider() { - var transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index d3672e3f3..71072855e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -33,7 +33,9 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override protected McpServerTransportProvider createMcpTransportProvider() { - transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); return transportProvider; } From e401e9e848cb6651451ed85c5fd7870727682031 Mon Sep 17 00:00:00 2001 From: codeboyzhou Date: Tue, 25 Mar 2025 14:57:43 +0800 Subject: [PATCH 005/205] feat(tests): Add unit tests for Assert and Utils classes (#70) Signed-off-by: Christian Tzolov --- .../util/AssertTests.java | 46 +++++++++++++++++++ .../modelcontextprotocol/util/UtilsTests.java | 40 ++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java new file mode 100644 index 000000000..08555fef5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class AssertTests { + + @Test + void testCollectionNotEmpty() { + IllegalArgumentException e1 = assertThrows(IllegalArgumentException.class, + () -> Assert.notEmpty(null, "collection is null")); + assertEquals("collection is null", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class, + () -> Assert.notEmpty(List.of(), "collection is empty")); + assertEquals("collection is empty", e2.getMessage()); + + assertDoesNotThrow(() -> Assert.notEmpty(List.of("test"), "collection is not empty")); + } + + @Test + void testObjectNotNull() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Assert.notNull(null, "object is null")); + assertEquals("object is null", e.getMessage()); + + assertDoesNotThrow(() -> Assert.notNull("test", "object is not null")); + } + + @Test + void testStringHasText() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Assert.hasText(null, "string is null")); + assertEquals("string is null", e.getMessage()); + + assertDoesNotThrow(() -> Assert.hasText("test", "string is not empty")); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java new file mode 100644 index 000000000..aced20cbc --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import org.junit.jupiter.api.Test; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class UtilsTests { + + @Test + void testHasText() { + assertFalse(Utils.hasText(null)); + assertFalse(Utils.hasText("")); + assertFalse(Utils.hasText(" ")); + assertTrue(Utils.hasText("test")); + } + + @Test + void testCollectionIsEmpty() { + assertTrue(Utils.isEmpty((Collection) null)); + assertTrue(Utils.isEmpty(List.of())); + assertFalse(Utils.isEmpty(List.of("test"))); + } + + @Test + void testMapIsEmpty() { + assertTrue(Utils.isEmpty((Map) null)); + assertTrue(Utils.isEmpty(Map.of())); + assertFalse(Utils.isEmpty(Map.of("key", "value"))); + } + +} \ No newline at end of file From 79ec5b5ed1cc1a7abf2edda313a81875bd75ad86 Mon Sep 17 00:00:00 2001 From: codeboyzz Date: Sat, 29 Mar 2025 13:08:23 +0800 Subject: [PATCH 006/205] fix(tests): Failed to start process with command npx on Windows (#85) * fix(tests): Failed to start process with command npx on Windows platform while running mvn test --- .../client/StdioMcpAsyncClientTests.java | 14 +++++++++++--- .../client/StdioMcpSyncClientTests.java | 15 +++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index 95230942c..c39080138 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -22,9 +22,17 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); + ServerParameters stdioParams; + if (System.getProperty("os.name").toLowerCase().contains("win")) { + stdioParams = ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + } + else { + stdioParams = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + } return new StdioClientTransport(stdioParams); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 925852b5b..8e75c4a3d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -30,10 +30,17 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") - .build(); - + ServerParameters stdioParams; + if (System.getProperty("os.name").toLowerCase().contains("win")) { + stdioParams = ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + } + else { + stdioParams = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything", "dir") + .build(); + } return new StdioClientTransport(stdioParams); } From 15a55b6b62945b6fd554826385eef92409b1d522 Mon Sep 17 00:00:00 2001 From: codezjx Date: Sat, 5 Apr 2025 19:13:15 +0800 Subject: [PATCH 007/205] fix: add support to set instructions as mentioned in #98 (#99) --- .../server/McpAsyncServer.java | 5 ++- .../server/McpServer.java | 33 +++++++++++++++++-- .../server/McpServerFeatures.java | 19 ++++++++--- 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 188b0f48e..df9386685 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -247,6 +247,8 @@ private static class AsyncServerImpl extends McpAsyncServer { private final McpSchema.Implementation serverInfo; + private final String instructions; + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); @@ -265,6 +267,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); this.tools.addAll(features.tools()); this.resources.putAll(features.resources()); this.resourceTemplates.addAll(features.resourceTemplates()); @@ -351,7 +354,7 @@ private Mono asyncInitializeRequestHandler( } return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); + this.serverInfo, this.instructions)); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 091efac2f..d5427335d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -160,6 +160,8 @@ class AsyncSpecification { private McpSchema.ServerCapabilities serverCapabilities; + private String instructions; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -228,6 +230,18 @@ public AsyncSpecification serverInfo(String name, String version) { return this; } + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public AsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -549,7 +563,7 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new McpAsyncServer(this.transportProvider, mapper, features); } @@ -572,6 +586,8 @@ class SyncSpecification { private McpSchema.ServerCapabilities serverCapabilities; + private String instructions; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -640,6 +656,18 @@ public SyncSpecification serverInfo(String name, String version) { return this; } + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public SyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -960,7 +988,8 @@ public SyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, + this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8c110027c..e0f337b78 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -35,12 +35,14 @@ public class McpServerFeatures { * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes + * @param instructions The server instructions text */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, List resourceTemplates, Map prompts, - List, Mono>> rootsChangeConsumers) { + List, Mono>> rootsChangeConsumers, + String instructions) { /** * Create an instance and validate the arguments. @@ -52,12 +54,14 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes + * @param instructions The server instructions text */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, List resourceTemplates, Map prompts, - List, Mono>> rootsChangeConsumers) { + List, Mono>> rootsChangeConsumers, + String instructions) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -78,6 +82,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); + this.instructions = instructions; } /** @@ -113,7 +118,7 @@ static Async fromSync(Sync syncSpec) { } return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, rootChangeConsumers); + syncSpec.resourceTemplates(), prompts, rootChangeConsumers, syncSpec.instructions()); } } @@ -128,13 +133,14 @@ static Async fromSync(Sync syncSpec) { * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes + * @param instructions The server instructions text */ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, List resourceTemplates, Map prompts, - List>> rootsChangeConsumers) { + List>> rootsChangeConsumers, String instructions) { /** * Create an instance and validate the arguments. @@ -146,13 +152,15 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes + * @param instructions The server instructions text */ Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, List resourceTemplates, Map prompts, - List>> rootsChangeConsumers) { + List>> rootsChangeConsumers, + String instructions) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -173,6 +181,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); + this.instructions = instructions; } } From bda3cab843c5d0f189919c91c78d8928b902b10f Mon Sep 17 00:00:00 2001 From: JermaineHua Date: Sat, 22 Mar 2025 22:14:38 +0800 Subject: [PATCH 008/205] Fix MCP schema link error Signed-off-by: JermaineHua --- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 37d9e0c0a..7749cd93d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -24,7 +24,7 @@ /** * Based on the JSON-RPC 2.0 * specification and the Model + * "https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.ts">Model * Context Protocol Schema. * * @author Christian Tzolov From 8d5872fd666b4c32d5d29318b389c42f41d968c0 Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Sun, 30 Mar 2025 20:49:04 +0200 Subject: [PATCH 009/205] feat(McpSchema): CallToolResult and CallToolRequest usability improvements (#87) - Add constructor to CallToolResult with one String entry - Add a new constructor to CallToolRequest that accepts JSON string arguments - Implement a builder pattern for CallToolResult with methods for adding content items - Add test coverage for new functionality Signed-off-by: Christian Tzolov Co-authored-by: Christian Tzolov --- .../modelcontextprotocol/spec/McpSchema.java | 111 +++++++++++++++++ .../spec/McpSchemaTests.java | 112 ++++++++++++++++++ 2 files changed, 223 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 7749cd93d..e38403c32 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -741,6 +742,19 @@ private static JsonSchema parseSchema(String schema) { public record CallToolRequest(// @formatter:off @JsonProperty("name") String name, @JsonProperty("arguments") Map arguments) implements Request { + + public CallToolRequest(String name, String jsonArguments) { + this(name, parseJsonArguments(jsonArguments)); + } + + private static Map parseJsonArguments(String jsonArguments) { + try { + return OBJECT_MAPPER.readValue(jsonArguments, MAP_TYPE_REF); + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); + } + } }// @formatter:off /** @@ -756,6 +770,103 @@ public record CallToolRequest(// @formatter:off public record CallToolResult( // @formatter:off @JsonProperty("content") List content, @JsonProperty("isError") Boolean isError) { + + /** + * Creates a new instance of {@link CallToolResult} with a string containing the + * tool result. + * + * @param content The content of the tool result. This will be mapped to a one-sized list + * with a {@link TextContent} element. + * @param isError If true, indicates that the tool execution failed and the content contains error information. + * If false or absent, indicates successful execution. + */ + public CallToolResult(String content, Boolean isError) { + this(List.of(new TextContent(content)), isError); + } + + /** + * Creates a builder for {@link CallToolResult}. + * @return a new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CallToolResult}. + */ + public static class Builder { + private List content = new ArrayList<>(); + private Boolean isError; + + /** + * Sets the content list for the tool result. + * @param content the content list + * @return this builder + */ + public Builder content(List content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + /** + * Sets the text content for the tool result. + * @param textContent the text content + * @return this builder + */ + public Builder textContent(List textContent) { + Assert.notNull(textContent, "textContent must not be null"); + textContent.stream() + .map(TextContent::new) + .forEach(this.content::add); + return this; + } + + /** + * Adds a content item to the tool result. + * @param contentItem the content item to add + * @return this builder + */ + public Builder addContent(Content contentItem) { + Assert.notNull(contentItem, "contentItem must not be null"); + if (this.content == null) { + this.content = new ArrayList<>(); + } + this.content.add(contentItem); + return this; + } + + /** + * Adds a text content item to the tool result. + * @param text the text content + * @return this builder + */ + public Builder addTextContent(String text) { + Assert.notNull(text, "text must not be null"); + return addContent(new TextContent(text)); + } + + /** + * Sets whether the tool execution resulted in an error. + * @param isError true if the tool execution failed, false otherwise + * @return this builder + */ + public Builder isError(Boolean isError) { + Assert.notNull(isError, "isError must not be null"); + this.isError = isError; + return this; + } + + /** + * Builds a new {@link CallToolResult} instance. + * @return a new CallToolResult instance + */ + public CallToolResult build() { + return new CallToolResult(content, isError); + } + } + } // @formatter:on // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index 1b8adc33b..a41fc095f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -6,6 +6,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import com.fasterxml.jackson.databind.ObjectMapper; @@ -493,6 +494,25 @@ void testCallToolRequest() throws Exception { {"name":"test-tool","arguments":{"name":"test","value":42}}""")); } + @Test + void testCallToolRequestJsonArguments() throws Exception { + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", """ + { + "name": "test", + "value": 42 + } + """); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"name":"test-tool","arguments":{"name":"test","value":42}}""")); + } + @Test void testCallToolResult() throws Exception { McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); @@ -508,6 +528,98 @@ void testCallToolResult() throws Exception { {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); } + @Test + void testCallToolResultBuilder() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Tool execution result") + .isError(false) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Tool execution result"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithMultipleContents() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addContent(textContent) + .addContent(imageContent) + .isError(false) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":false}""")); + } + + @Test + void testCallToolResultBuilderWithContentList() throws Exception { + McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + List contents = Arrays.asList(textContent, imageContent); + + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"content":[{"type":"text","text":"Text result"},{"type":"image","data":"base64data","mimeType":"image/png"}],"isError":true}""")); + } + + @Test + void testCallToolResultBuilderWithErrorResult() throws Exception { + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .addTextContent("Error: Operation failed") + .isError(true) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Error: Operation failed"}],"isError":true}""")); + } + + @Test + void testCallToolResultStringConstructor() throws Exception { + // Test the existing string constructor alongside the builder + McpSchema.CallToolResult result1 = new McpSchema.CallToolResult("Simple result", false); + McpSchema.CallToolResult result2 = McpSchema.CallToolResult.builder() + .addTextContent("Simple result") + .isError(false) + .build(); + + String value1 = mapper.writeValueAsString(result1); + String value2 = mapper.writeValueAsString(result2); + + // Both should produce the same JSON + assertThat(value1).isEqualTo(value2); + assertThatJson(value1).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"content":[{"type":"text","text":"Simple result"}],"isError":false}""")); + } + // Sampling Tests @Test From eb8e3744a7903932ab02982e506d6a75e79fe3ab Mon Sep 17 00:00:00 2001 From: Renxia Wang Date: Sat, 29 Mar 2025 20:49:51 -0400 Subject: [PATCH 010/205] feat(transport): Add customizable HTTP request builder support (#86) Enhances FlowSseClient and HttpClientSseClientTransport to accept a custom HttpRequest.Builder, allowing for greater flexibility when configuring HTTP requests. This enables clients to customize headers, timeouts, and other request properties across all SSE connections and message sending operations. Signed-off-by: Christian Tzolov --- .../client/transport/FlowSseClient.java | 15 ++++++- .../HttpClientSseClientTransport.java | 41 +++++++++++++++++-- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index 7fc679937..50af35c70 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -39,6 +39,8 @@ public class FlowSseClient { private final HttpClient httpClient; + private final HttpRequest.Builder requestBuilder; + /** * Pattern to extract the data content from SSE data field lines. Matches lines * starting with "data:" and captures the remaining content. @@ -92,7 +94,17 @@ public interface SseEventHandler { * @param httpClient the {@link HttpClient} instance to use for SSE connections */ public FlowSseClient(HttpClient httpClient) { + this(httpClient, HttpRequest.newBuilder()); + } + + /** + * Creates a new FlowSseClient with the specified HTTP client and request builder. + * @param httpClient the {@link HttpClient} instance to use for SSE connections + * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests + */ + public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) { this.httpClient = httpClient; + this.requestBuilder = requestBuilder; } /** @@ -109,8 +121,7 @@ public FlowSseClient(HttpClient httpClient) { * @throws RuntimeException if the connection fails with a non-200 status code */ public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(url)) + HttpRequest request = this.requestBuilder.uri(URI.create(url)) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .GET() diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 696efdffd..0b482533d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -82,6 +82,9 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ private final HttpClient httpClient; + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + /** JSON object mapper for message serialization/deserialization */ protected ObjectMapper objectMapper; @@ -126,15 +129,33 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas */ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param clientBuilder the HTTP client builder to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, + String baseUri, String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(clientBuilder, "clientBuilder must not be null"); + Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = baseUri; this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); - this.sseClient = new FlowSseClient(this.httpClient); + this.requestBuilder = requestBuilder; + + this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); } /** @@ -159,6 +180,8 @@ public static class Builder { private ObjectMapper objectMapper = new ObjectMapper(); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -190,6 +213,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) { return this; } + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + /** * Sets the object mapper for JSON serialization/deserialization. * @param objectMapper the object mapper @@ -206,7 +240,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder, baseUri, sseEndpoint, objectMapper); + return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper); } } @@ -301,8 +335,7 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(this.baseUri + endpoint)) + HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); From c3a7c1ac1e04c141e95df1b1a77dc127b7ce0311 Mon Sep 17 00:00:00 2001 From: jitokim Date: Sun, 6 Apr 2025 02:12:34 +0900 Subject: [PATCH 011/205] perf(webflux): optimize session broadcasting with Flux.fromIterable (#109) Replace Flux.fromStream(sessions.values().stream()) with more efficient Flux.fromIterable(sessions.values()) to eliminate unnecessary stream conversion when broadcasting messages to active sessions Signed-off-by: jitokim --- .../server/transport/WebFluxSseServerTransportProvider.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 85a39a82f..af2ff06a3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -171,10 +171,10 @@ public Mono notifyClients(String method, Map params) { logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromStream(sessions.values().stream()) + return Flux.fromIterable(sessions.values()) .flatMap(session -> session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), - e.getMessage())) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) .onErrorComplete()) .then(); } From cd624a7d5719db9648711c986be9bc9a149a34e4 Mon Sep 17 00:00:00 2001 From: Oleksandr Popov Date: Sun, 6 Apr 2025 10:29:29 +0200 Subject: [PATCH 012/205] fix: correct typos and improve documentation (#35) Signed-off-by: Christian Tzolov --- mcp/pom.xml | 2 +- .../io/modelcontextprotocol/client/McpAsyncClient.java | 10 +++++++++- .../client/transport/HttpClientSseClientTransport.java | 2 +- .../client/transport/StdioClientTransport.java | 2 +- .../io/modelcontextprotocol/spec/McpClientSession.java | 4 ++-- .../spec/McpClientSessionTests.java | 2 +- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mcp/pom.xml b/mcp/pom.xml index f6e93b39c..edb1c8f07 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -97,7 +97,7 @@ test - diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 379b47e23..ce49b0a5e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -364,7 +364,7 @@ private Mono withInitializationCheck(String actionName, } // -------------------------- - // Basic Utilites + // Basic Utilities // -------------------------- /** @@ -751,6 +751,14 @@ private NotificationHandler asyncPromptsChangeNotificationHandler( // -------------------------- // Logging // -------------------------- + /** + * Create a notification handler for logging notifications from the server. This + * handler automatically distributes logging messages to all registered consumers. + * @param loggingConsumers List of consumers that will be notified when a logging + * message is received. Each consumer receives the logging message notification. + * @return A NotificationHandler that processes log notifications by distributing the + * message to all registered consumers + */ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 0b482533d..a5bdd43e2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -376,7 +376,7 @@ public Mono closeGracefully() { } /** - * Unmarshals data to the specified type using the configured object mapper. + * Unmarshal data to the specified type using the configured object mapper. * @param data the data to unmarshal * @param typeRef the type reference for the target type * @param the target type diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index f9a97849f..9d71cbb48 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -292,7 +292,7 @@ private void startInboundProcessing() { */ private void startOutboundProcessing() { this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads and we + // this bit is important since writes come from user threads, and we // want to ensure that the actual writing happens on a dedicated thread .publishOn(outboundScheduler) .handle((message, s) -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index e29646e6a..719a78001 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -107,7 +107,7 @@ public interface NotificationHandler { public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { - Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(requestTimeout, "The requestTimeout can not be null"); Assert.notNull(transport, "The transport can not be null"); Assert.notNull(requestHandlers, "The requestHandlers can not be null"); Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); @@ -127,7 +127,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); if (sink == null) { - logger.warn("Unexpected response for unkown id {}", response.id()); + logger.warn("Unexpected response for unknown id {}", response.id()); } else { sink.success(response); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 715d6651e..f72be43e0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -61,7 +61,7 @@ void tearDown() { void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("requstTimeout can not be null"); + .hasMessageContaining("The requestTimeout can not be null"); assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) From 734153a445585cd452f041520583e6517f3674f3 Mon Sep 17 00:00:00 2001 From: jitokim Date: Sun, 6 Apr 2025 02:10:43 +0900 Subject: [PATCH 013/205] fix typo in WebFluxSseIntegrationTests Signed-off-by: jitokim --- .../WebFluxSseIntegrationTests.java | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2be2f81f2..dbfad821f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -52,8 +52,6 @@ public class WebFluxSseIntegrationTests { private static final int PORT = 8182; - // private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; @@ -62,7 +60,7 @@ public class WebFluxSseIntegrationTests { private WebFluxSseServerTransportProvider mcpServerTransportProvider; - ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @BeforeEach public void before() { @@ -77,11 +75,11 @@ public void before() { ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - clientBulders.put("httpclient", + clientBuilders.put("httpclient", McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build())); - clientBulders.put("webflux", + clientBuilders.put("webflux", McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) @@ -103,7 +101,7 @@ public void after() { @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageWithoutSamplingCapabilities(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -134,7 +132,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { void testCreateMessageSuccess(String clientType) throws InterruptedException { // Client - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); @@ -203,7 +201,7 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); @@ -250,7 +248,7 @@ void testRootsSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithoutCapability(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -284,7 +282,7 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsNotifciationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) @@ -311,7 +309,7 @@ void testRootsNotifciationWithEmptyRootsList(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithMultipleHandlers(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -345,7 +343,7 @@ void testRootsWithMultipleHandlers(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testRootsServerCloseWithActiveSubscription(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -390,7 +388,7 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testToolCallSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( @@ -430,7 +428,7 @@ void testToolCallSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testToolListChangeHandlingSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( @@ -500,7 +498,7 @@ void testToolListChangeHandlingSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testInitialize(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); From 0db4c0f70d28c72ec94750c289e001d54a5bace6 Mon Sep 17 00:00:00 2001 From: minguncle <57527858+minguncle@users.noreply.github.com> Date: Thu, 27 Mar 2025 15:53:30 +0800 Subject: [PATCH 014/205] feat(webmvc): Add support for custom context paths in WebMvcSseServerTransportProvider Adds the ability to specify a base URL for message endpoints in WebMvcSseServerTransportProvider, enabling proper handling of custom servlet context paths in Spring WebMVC applications. This ensures that clients receive the correct full endpoint URL when connecting through SSE. - Add messageBaseUrl field to WebMvcSseServerTransportProvider - Create new constructor that accepts messageBaseUrl parameter - Update endpoint event to include base URL in the message endpoint - Add TomcatTestUtil class to simplify test server creation - Add WebMvcSseCustomContextPathTests to verify custom context path functionality - Refactor WebMvcSseIntegrationTests to use the new TomcatTestUtil Co-authored-by: Christian Tzolov Signed-off-by: Christian Tzolov --- .../WebMvcSseServerTransportProvider.java | 51 ++++++--- .../server/TomcatTestUtil.java | 60 ++++++++++ .../WebMvcSseCustomContextPathTests.java | 105 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 62 +++-------- mcp-test/pom.xml | 1 + 5 files changed, 216 insertions(+), 63 deletions(-) create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 65416b256..f6dbd4779 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -91,6 +91,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String sseEndpoint; + private final String messageBaseUrl; + private final RouterFunction routerFunction; private McpServerSession.Factory sessionFactory; @@ -105,6 +107,20 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -116,11 +132,30 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, "", messageEndpoint, sseEndpoint); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageBaseUrl The base URL for the message endpoint, used to construct the + * full endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageBaseUrl, String messageEndpoint, + String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageBaseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; + this.messageBaseUrl = messageBaseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -129,20 +164,6 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag .build(); } - /** - * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; @@ -248,7 +269,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(messageEndpoint + "?sessionId=" + sessionId); + .data(this.messageBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java new file mode 100644 index 000000000..fcd7fb4dc --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -0,0 +1,60 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { + } + + public TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { + + // Set up Tomcat first + var tomcat = new Tomcat(); + tomcat.setPort(port); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext(contextPath, baseDir); + + // Create and configure Spring WebMvc context + var appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(componentClass); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return new TomcatServer(tomcat, appContext); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java new file mode 100644 index 000000000..0e81104b9 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +public class WebMvcSseCustomContextPathTests { + + private static final String CUSTOM_CONTEXT_PATH = "/app/1"; + + private static final int PORT = 8183; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(clientTransport); + + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 3ff755ca9..f9190fd70 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -25,10 +25,8 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,15 +36,12 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.client.RestClient; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; -import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; public class WebMvcSseIntegrationTests { @@ -75,55 +70,26 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro } - private Tomcat tomcat; - - private AnnotationConfigWebApplicationContext appContext; + private TomcatTestUtil.TomcatServer tomcatServer; @BeforeEach public void before() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addServletMappingDecoded("/*", "dispatcherServlet"); + tomcatServer = new TomcatTestUtil().createTomcatServer("", PORT, TestConfig.class); try { - // Configure and start the connector with async support - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + + // Get the transport from Spring context + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } @AfterEach @@ -131,13 +97,13 @@ public void after() { if (mcpServerTransportProvider != null) { mcpServerTransportProvider.closeGracefully().block(); } - if (appContext != null) { - appContext.close(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); } - if (tomcat != null) { + if (tomcatServer.tomcat() != null) { try { - tomcat.stop(); - tomcat.destroy(); + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); } catch (LifecycleException e) { throw new RuntimeException("Failed to stop Tomcat", e); diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index b995618af..95f5dc30a 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -80,6 +80,7 @@ logback-classic ${logback.version} + From 3fa415228757c78bdbcabdfeaf548b3b09882b2f Mon Sep 17 00:00:00 2001 From: zhangzhenhua Date: Wed, 2 Apr 2025 13:54:14 +0800 Subject: [PATCH 015/205] feat(webflux): Add base URL support to WebFluxSseServerTransportProvider (#102) Adds the ability to specify a base URL prefix for message endpoints in the WebFlux SSE server transport provider. This enhancement allows for proper URL construction when the server is running behind a proxy or in a context with a base path. - Add new constructor with baseUrl parameter - Add basePath() method to Builder class - Modify SSE endpoint event to include baseUrl prefix Signed-off-by: Christian Tzolov --- .../WebFluxSseServerTransportProvider.java | 75 ++++++++++++++----- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index af2ff06a3..df8dd0211 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -82,8 +82,16 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String DEFAULT_BASE_URL = ""; + private final ObjectMapper objectMapper; + /** + * Base URL for the message endpoint. This is used to construct the full URL for + * clients to send their JSON-RPC messages. + */ + private final String baseUrl; + private final String messageEndpoint; private final String sseEndpoint; @@ -102,6 +110,20 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + /** * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -112,11 +134,28 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux messag base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -125,20 +164,6 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa .build(); } - /** - * Constructs a new WebFlux SSE server transport provider instance with the default - * SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; @@ -179,7 +204,8 @@ public Mono notifyClients(String method, Map params) { .then(); } - // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // FIXME: This javadoc makes claims about using isClosing flag but it's not + // actually // doing that. /** * Initiates a graceful shutdown of all the sessions. This method ensures all active @@ -245,7 +271,7 @@ private Mono handleSseConnection(ServerRequest request) { logger.debug("Sending initial endpoint event to session: {}", sessionId); sink.next(ServerSentEvent.builder() .event(ENDPOINT_EVENT_TYPE) - .data(messageEndpoint + "?sessionId=" + sessionId) + .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) .build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); @@ -360,6 +386,8 @@ public static class Builder { private ObjectMapper objectMapper; + private String baseUrl = DEFAULT_BASE_URL; + private String messageEndpoint; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; @@ -377,6 +405,19 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the project basePath as endpoint prefix where clients should send their + * JSON-RPC messages + * @param baseUrl the message basePath . Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if basePath is null + */ + public Builder basePath(String baseUrl) { + Assert.notNull(baseUrl, "basePath must not be null"); + this.baseUrl = baseUrl; + return this; + } + /** * Sets the endpoint URI where clients should send their JSON-RPC messages. * @param messageEndpoint The message endpoint URI. Must not be null. @@ -411,7 +452,7 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); } } From b21cfab10ec9d51c8f57541767337dfd790a43b2 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 6 Apr 2025 16:33:25 +0200 Subject: [PATCH 016/205] refactor(webmvc): Rename messageBaseUrl to baseUrl for consistency Signed-off-by: Christian Tzolov --- .../WebMvcSseServerTransportProvider.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index f6dbd4779..fa2e357f9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -91,7 +91,7 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String sseEndpoint; - private final String messageBaseUrl; + private final String baseUrl; private final RouterFunction routerFunction; @@ -139,23 +139,23 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of messages. - * @param messageBaseUrl The base URL for the message endpoint, used to construct the - * full endpoint URL for clients. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageBaseUrl, String messageEndpoint, + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageBaseUrl, "Message base URL must not be null"); + Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; - this.messageBaseUrl = messageBaseUrl; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -269,7 +269,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(this.messageBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); From 8fc72aed88616cfe4ba4fe8adae038b32fcc9f8b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 6 Apr 2025 18:41:25 +0200 Subject: [PATCH 017/205] feat(mcp): Add support for custom context paths in HTTP Servlet SSE server transport Enhance HttpServletSseServerTransportProvider to support deployment under non-root context paths by: - Adding baseUrl field and DEFAULT_BASE_URL constant - Creating new constructor that accepts a baseUrl parameter - Extending Builder with baseUrl configuration method - Prepending baseUrl to message endpoint in SSE events - Add HttpServletSseServerCustomContextPathTests to verify custom context path functionality - Extract common Tomcat server setup code to TomcatTestUtil for test reuse Related to #79 Signed-off-by: Christian Tzolov --- ...HttpServletSseServerTransportProvider.java | 37 +++++++- ...ervletSseServerCustomContextPathTests.java | 86 +++++++++++++++++++ ...rverTransportProviderIntegrationTests.java | 21 +---- .../server/transport/TomcatTestUtil.java | 45 ++++++++++ 4 files changed, 167 insertions(+), 22 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index a64b4a353..e52fc88b7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -80,9 +80,14 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Event type for endpoint information */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String DEFAULT_BASE_URL = ""; + /** JSON object mapper for serialization/deserialization */ private final ObjectMapper objectMapper; + /** Base URL for the server transport */ + private final String baseUrl; + /** The endpoint path for handling client messages */ private final String messageEndpoint; @@ -108,7 +113,22 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint) { this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; } @@ -203,7 +223,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } /** @@ -449,6 +469,8 @@ public static class Builder { private ObjectMapper objectMapper = new ObjectMapper(); + private String baseUrl = DEFAULT_BASE_URL; + private String messageEndpoint; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; @@ -464,6 +486,17 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + /** * Sets the endpoint path where clients will send their messages. * @param messageEndpoint The message endpoint path @@ -502,7 +535,7 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java new file mode 100644 index 000000000..1254e2ad8 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class HttpServletSseServerCustomContextPathTests { + + private static final int PORT = 8195; + + private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, mcpServerTransportProvider); + + try { + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT) + .build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 1cd395e74..b04940c79 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -26,7 +26,6 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; @@ -59,14 +58,6 @@ public class HttpServletSseServerTransportProviderIntegrationTests { @BeforeEach public void before() { - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - Context context = tomcat.addContext("", baseDir); - // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) @@ -74,18 +65,8 @@ public void before() { .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); - // Add transport servlet to Tomcat - org.apache.catalina.Wrapper wrapper = context.createWrapper(); - wrapper.setName("mcpServlet"); - wrapper.setServlet(mcpServerTransportProvider); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addChild(wrapper); - context.addServletMappingDecoded("/*", "mcpServlet"); - + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); tomcat.start(); assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..6f922dfa6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,45 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; + +import static org.junit.Assert.assertThat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Context context = tomcat.addContext("", baseDir); + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + +} From 13c4474b3ea00e75be47b653d315ba9de7125cb3 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 16:05:07 +0200 Subject: [PATCH 018/205] Change the URLs used to test blocking rest calls Signed-off-by: Christian Tzolov --- .../io/modelcontextprotocol/WebFluxSseIntegrationTests.java | 6 +++--- .../server/WebMvcSseIntegrationTests.java | 6 +++--- ...tpServletSseServerTransportProviderIntegrationTests.java | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index dbfad821f..ac487b6f5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -396,7 +396,7 @@ void testToolCallSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -436,7 +436,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -453,7 +453,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index f9190fd70..420f4b987 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -388,7 +388,7 @@ void testToolCallSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() - .get() + .get()https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); @@ -424,7 +424,7 @@ void testToolListChangeHandlingSuccess() { McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service - String response = RestClient.create() + String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .get() .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() @@ -441,7 +441,7 @@ void testToolListChangeHandlingSuccess() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service - String response = RestClient.create() + String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .get() .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index b04940c79..e34baf9d6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -374,7 +374,7 @@ void testToolCallSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -411,7 +411,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -428,7 +428,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); From fbea833384c097a46927624f1f7cbb9562c15e74 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 16:17:16 +0200 Subject: [PATCH 019/205] Fix compilation issue introduced by the previous commit Signed-off-by: Christian Tzolov --- .../server/WebMvcSseIntegrationTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 420f4b987..c203e3bd5 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -388,8 +388,8 @@ void testToolCallSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() - .get()https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -424,9 +424,9 @@ void testToolListChangeHandlingSuccess() { McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service - String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md + String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -441,9 +441,9 @@ void testToolListChangeHandlingSuccess() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service - String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md + String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); From fab434c088e7e90ad4cbbedd55b28c553536c7de Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Mon, 7 Apr 2025 14:30:04 +0200 Subject: [PATCH 020/205] refactor(client): enhance HttpClientSseClientTransport with flexible customization API (#117) - Add builder customizeClient() and customizeRequest() methods - Enable HTTP client and request configuration through consumer-based customization - Deprecate direct constructors in favor of the more flexible builder approach - Add test coverage for customization capabilities Co-authored-by: Christian Tzolov Signed-off-by: Christian Tzolov --- .../server/WebMvcSseIntegrationTests.java | 12 +-- .../HttpClientSseClientTransport.java | 92 ++++++++++++++++-- .../client/HttpSseMcpAsyncClientTests.java | 4 +- .../client/HttpSseMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransportTests.java | 97 ++++++++++++++++++- 5 files changed, 185 insertions(+), 22 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index c203e3bd5..d5c9f90ff 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -44,7 +44,7 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class WebMvcSseIntegrationTests { +class WebMvcSseIntegrationTests { private static final int PORT = 8183; @@ -79,13 +79,13 @@ public void before() { try { tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -200,8 +200,7 @@ void testCreateMessageSuccess() throws InterruptedException { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); @@ -410,8 +409,7 @@ void testToolCallSuccess() { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index a5bdd43e2..632d3844a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -13,6 +13,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -103,7 +104,10 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(String baseUri) { this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); } @@ -114,7 +118,10 @@ public HttpClientSseClientTransport(String baseUri) { * @param baseUri the base URI of the MCP server * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); } @@ -126,7 +133,10 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); @@ -141,18 +151,37 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(clientBuilder, "clientBuilder must not be null"); + Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = baseUri; this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; - this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); + this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); @@ -164,7 +193,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @return a new builder instance */ public static Builder builder(String baseUri) { - return new Builder(baseUri); + return new Builder().baseUri(baseUri); } /** @@ -172,25 +201,50 @@ public static Builder builder(String baseUri) { */ public static class Builder { - private final String baseUri; + private String baseUri; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private HttpClient.Builder clientBuilder = HttpClient.newBuilder(); + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); private ObjectMapper objectMapper = new ObjectMapper(); - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .header("Content-Type", "application/json"); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. */ + @Deprecated(forRemoval = true) public Builder(String baseUri) { Assert.hasText(baseUri, "baseUri must not be empty"); this.baseUri = baseUri; } + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + /** * Sets the SSE endpoint path. * @param sseEndpoint the SSE endpoint path @@ -213,6 +267,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + /** * Sets the HTTP request builder. * @param requestBuilder the HTTP request builder @@ -224,6 +289,17 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + /** * Sets the object mapper for JSON serialization/deserialization. * @param objectMapper the object mapper @@ -240,7 +316,8 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper); + return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); } } @@ -336,7 +413,6 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) - .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 15749d4ff..fdff4b777 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -15,7 +15,7 @@ * * @author Christian Tzolov */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(15) class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { String host = "http://localhost:3004"; @@ -29,7 +29,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 067f92957..204cf2984 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -29,7 +29,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 294056fbe..e5178c0ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -4,9 +4,15 @@ package io.modelcontextprotocol.client.transport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; import java.time.Duration; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -26,6 +32,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Tests for the {@link HttpClientSseClientTransport} class. * @@ -51,8 +59,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(String baseUri) { - super(baseUri); + public TestHttpClientSseClientTransport(final String baseUri) { + super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); } public int getInboundMessageCount() { @@ -191,13 +199,14 @@ void testGracefulShutdown() { StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); + assertThat(transport.getInboundMessageCount()).isZero(); } @Test void testRetryBehavior() { // Create a client that simulates connection failures - HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host"); + HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host") + .build(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); @@ -275,4 +284,84 @@ void testMessageOrderPreservation() { assertThat(transport.getInboundMessageCount()).isEqualTo(3); } + @Test + void testCustomizeClient() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.version(HttpClient.Version.HTTP_2); + customizerCalled.set(true); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testCustomizeRequest() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a reference to store the custom header value + AtomicReference headerName = new AtomicReference<>(); + AtomicReference headerValue = new AtomicReference<>(); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + // Create a request customizer that adds a custom header + .customizeRequest(builder -> { + builder.header("X-Custom-Header", "test-value"); + customizerCalled.set(true); + + // Create a new request to verify the header was set + HttpRequest request = builder.uri(URI.create("http://example.com")).build(); + headerName.set("X-Custom-Header"); + headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null)); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Verify the header was set correctly + assertThat(headerName.get()).isEqualTo("X-Custom-Header"); + assertThat(headerValue.get()).isEqualTo("test-value"); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testChainedCustomizations() { + // Create atomic booleans to verify both customizers were called + AtomicBoolean clientCustomizerCalled = new AtomicBoolean(false); + AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false); + + // Create a transport with both customizers chained + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.connectTimeout(Duration.ofSeconds(30)); + clientCustomizerCalled.set(true); + }) + .customizeRequest(builder -> { + builder.header("X-Api-Key", "test-api-key"); + requestCustomizerCalled.set(true); + }) + .build(); + + // Verify both customizers were called + assertThat(clientCustomizerCalled.get()).isTrue(); + assertThat(requestCustomizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + } From 391ec19fdc346c6d0ebf369f692c370a48339d3d Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Thu, 10 Apr 2025 12:26:29 +0200 Subject: [PATCH 021/205] refactor: change notification params type from Map to Object (#137) * refactor: change notification params type from Map to Object This change generalizes the parameter type for notification methods across the MCP framework, allowing for more flexible parameter passing. Instead of requiring parameters to be structured as a Map, the API now accepts any Object as parameters. The primary motivation is to simplify client usage by allowing direct passing of strongly-typed objects without requiring conversion to a Map first, as demonstrated in the McpAsyncServer logging notification implementation. Affected components: - McpSession interface and implementations - McpServerTransportProvider interface and implementations - McpSchema JSONRPCNotification record --------- Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseServerTransportProvider.java | 2 +- .../server/transport/WebMvcSseServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/server/McpAsyncServer.java | 7 ++----- .../transport/HttpServletSseServerTransportProvider.java | 2 +- .../server/transport/StdioServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/spec/McpClientSession.java | 2 +- .../main/java/io/modelcontextprotocol/spec/McpSchema.java | 2 +- .../io/modelcontextprotocol/spec/McpServerSession.java | 2 +- .../spec/McpServerTransportProvider.java | 4 ++-- .../main/java/io/modelcontextprotocol/spec/McpSession.java | 4 ++-- .../MockMcpServerTransportProvider.java | 2 +- 11 files changed, 14 insertions(+), 17 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index df8dd0211..be30bd72f 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -188,7 +188,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * errors if any session fails to receive the message */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fa2e357f9..7bd1aa6c9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -179,7 +179,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * @return A Mono that completes when the broadcast attempt is finished */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index df9386685..ec2a04c9e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -669,15 +669,12 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN return Mono.error(new McpError("Logging message must not be null")); } - Map params = this.objectMapper.convertValue(loggingMessageNotification, - new TypeReference>() { - }); - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { return Mono.empty(); } - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); } private McpServerSession.RequestHandler setLoggerRequestHandler() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index e52fc88b7..afdbff472 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -160,7 +160,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * @return A Mono that completes when the broadcast attempt is finished */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index a8b980e90..819da9777 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -99,7 +99,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { } @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (this.session == null) { return Mono.error(new McpError("No session to close")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 719a78001..0895e02b0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -258,7 +258,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc * @return A Mono that completes when the notification is sent */ @Override - public Mono sendNotification(String method, Map params) { + public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e38403c32..4c596b628 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -191,7 +191,7 @@ public record JSONRPCRequest( // @formatter:off public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, - @JsonProperty("params") Map params) implements JSONRPCMessage { + @JsonProperty("params") Object params) implements JSONRPCMessage { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index bcdf22486..46014af8d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -132,7 +132,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc } @Override - public Mono sendNotification(String method, Map params) { + public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index dba8cc43f..5fdbd7ab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -42,11 +42,11 @@ public interface McpServerTransportProvider { /** * Sends a notification to all connected clients. * @param method the name of the notification method to be called on the clients - * @param params a map of parameters to be sent with the notification + * @param params parameters to be sent with the notification * @return a Mono that completes when the notification has been broadcast * @see McpSession#sendNotification(String, Map) */ - Mono notifyClients(String method, Map params); + Mono notifyClients(String method, Object params); /** * Immediately closes all the transports with connected clients and releases any diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index b97c3ccc4..473a860c2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -63,10 +63,10 @@ default Mono sendNotification(String method) { * parameters with the notification. *

* @param method the name of the notification method to be sent to the counterparty - * @param params a map of parameters to be sent with the notification + * @param params parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ - Mono sendNotification(String method, Map params); + Mono sendNotification(String method, Object params); /** * Closes the session and releases any associated resources asynchronously. diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 3fb19180b..20a8c0cf5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -47,7 +47,7 @@ public void setSessionFactory(Factory sessionFactory) { } @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { return session.sendNotification(method, params); } From 2895d1589ac3c81366eccfc584c6c733d5846127 Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Thu, 10 Apr 2025 13:18:55 +0200 Subject: [PATCH 022/205] fix: Add null check for session in WebFluxSseServerTransportProvider (#138) Add error handling to return a 404 NOT_FOUND response when a request is made with a non-existent session ID. This prevents potential NullPointerExceptions when processing requests with invalid session IDs. Signed-off-by: Christian Tzolov --- .../server/transport/WebFluxSseServerTransportProvider.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index be30bd72f..eed8a53af 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -306,6 +306,11 @@ private Mono handleMessage(ServerRequest request) { McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); + } + return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); From c88ac937f3e195c7e767c61e5024737c3417ad72 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 11:15:16 +0200 Subject: [PATCH 023/205] feat(mcp): refactor logging to use exchange for targeted client notifications (#132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the MCP logging system to use the exchange mechanism for sending logging notifications only to specific client sessions rather than broadcasting to all clients. - Move logging notification delivery from server-wide broadcast to per-session exchange - Implement per-session minimum logging level tracking and filtering - Add proper logging level filtering at the exchange level - Change setLoggingLevel from notification to request/response pattern (breaking change) - Deprecate global server.loggingNotification in favor of exchange.loggingNotification - Add SetLevelRequest record to McpSchema - Add integration test demonstrating filtered logging notifications Resolves #131 Signed-off-by: Christian Tzolov Co-authored-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 356 +++++++++++------ .../server/WebMvcSseIntegrationTests.java | 260 +++++++------ .../client/AbstractMcpAsyncClientTests.java | 14 +- .../server/AbstractMcpAsyncServerTests.java | 49 --- .../server/AbstractMcpSyncServerTests.java | 49 --- .../client/McpAsyncClient.java | 7 +- .../client/McpSyncClient.java | 1 - .../server/McpAsyncServer.java | 31 +- .../server/McpAsyncServerExchange.java | 44 +++ .../server/McpSyncServer.java | 13 +- .../server/McpSyncServerExchange.java | 17 +- .../modelcontextprotocol/spec/McpSchema.java | 5 + .../client/AbstractMcpAsyncClientTests.java | 14 +- .../server/AbstractMcpAsyncServerTests.java | 49 --- .../server/AbstractMcpSyncServerTests.java | 49 --- ...ervletSseServerCustomContextPathTests.java | 11 +- ...rverTransportProviderIntegrationTests.java | 365 ++++++++++++------ 17 files changed, 721 insertions(+), 613 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index ac487b6f5..d71fe1ab0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -111,27 +112,28 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { - assertThat(client.initialize()).isNotNull(); + assertThat(client.initialize()).isNotNull(); - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageSuccess(String clientType) throws InterruptedException { - // Client var clientBuilder = clientBuilders.get(clientType); Function samplingHandler = request -> { @@ -142,13 +144,6 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -183,15 +178,19 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } mcpServer.close(); } @@ -206,41 +205,42 @@ void testRootsSuccess(String clientType) { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -261,21 +261,21 @@ void testRootsWithoutCapability(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try ( + // Create client without roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -285,30 +285,31 @@ void testRootsNotifciationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithMultipleHandlers(String clientType) { + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -321,21 +322,21 @@ void testRootsWithMultipleHandlers(String clientType) { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -348,28 +349,26 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -378,9 +377,9 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { String emptyJsonSchema = """ { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} } """; @@ -408,19 +407,19 @@ void testToolCallSuccess(String clientType) { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -443,13 +442,14 @@ void testToolListChangeHandlingSuccess(String clientType) { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -458,39 +458,40 @@ void testToolListChangeHandlingSuccess(String clientType) { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -502,12 +503,115 @@ void testInitialize(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), + (exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }); - mcpClient.close(); + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(3); + + // First notification should be NOTICE level + assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + }); + } mcpServer.close(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index d5c9f90ff..be01365a1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -125,27 +125,34 @@ void testCreateMessageWithoutSamplingCapabilities() { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + //@formatter:off + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder + .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) {//@formatter:on + + assertThat(client.initialize()).isNotNull(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @Test void testCreateMessageSuccess() throws InterruptedException { - // Client - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -154,13 +161,6 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -190,19 +190,25 @@ void testCreateMessageSuccess() throws InterruptedException { return Mono.just(callResponse); }); + //@formatter:off var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try ( + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) {//@formatter:on - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull().isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } mcpServer.close(); } @@ -214,41 +220,42 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -266,21 +273,22 @@ void testRootsWithoutCapability() { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -292,20 +300,20 @@ void testRootsNotifciationWithEmptyRootsList() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -321,20 +329,20 @@ void testRootsWithMultipleHandlers() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -343,28 +351,26 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -400,18 +406,18 @@ void testToolCallSuccess() { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull().isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -431,13 +437,14 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -446,39 +453,40 @@ void testToolListChangeHandlingSuccess() { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -487,12 +495,12 @@ void testInitialize() { var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } - mcpClient.close(); mcpServer.close(); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 713563519..5452c8eac 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -453,15 +454,10 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { withClient(createMcpTransport(), mcpAsyncClient -> { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); - - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 7bcb9a8b2..a91632c6c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -416,53 +416,4 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 7846e053b..9a63143c9 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -388,53 +388,4 @@ void testRootsChangeHandlers() { assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index ce49b0a5e..df099836d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -786,10 +786,9 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { } return this.withInitializationCheck("setting logging level", initializedResult -> { - String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { - }); - Map params = Map.of("level", levelName); - return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); + var params = new McpSchema.SetLevelRequest(loggingLevel); + return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() { + }).then(); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 071d76462..32cf325e9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -6,7 +6,6 @@ import java.time.Duration; -import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index ec2a04c9e..062de13ed 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; @@ -216,11 +217,17 @@ public Mono notifyPromptsListChanged() { // --------------------------------------- /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. * @param loggingMessageNotification The logging message to send * @return A Mono that completes when the notification has been sent + * @deprecated Use + * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. */ + @Deprecated public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { return this.delegate.loggingNotification(loggingMessageNotification); } @@ -257,6 +264,8 @@ private static class AsyncServerImpl extends McpAsyncServer { private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); @@ -677,12 +686,22 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpServerSession.RequestHandler setLoggerRequestHandler() { + private McpServerSession.RequestHandler setLoggerRequestHandler() { return (exchange, params) -> { - this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { - }); + return Mono.defer(() -> { - return Mono.empty(); + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { + }); + + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); + + return Mono.just(Map.of()); + }); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 658628448..889dc66d0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -1,9 +1,16 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + package io.modelcontextprotocol.server; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; /** @@ -11,6 +18,7 @@ * exchange provides methods to interact with the client and query its capabilities. * * @author Dariusz Jędrzejczyk + * @author Christian Tzolov */ public class McpAsyncServerExchange { @@ -20,6 +28,8 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { }; @@ -101,4 +111,38 @@ public Mono listRoots(String cursor) { LIST_ROOTS_RESULT_TYPE_REF); } + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + return Mono.defer(() -> { + if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); + } + return Mono.empty(); + }); + } + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 72eba8b86..bf3104508 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,9 +4,7 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -151,9 +149,16 @@ public void notifyPromptsListChanged() { } /** - * Send a logging message notification to all clients. - * @param loggingMessageNotification The logging message notification to send + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @deprecated Use + * {@link McpSyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. */ + @Deprecated public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { this.asyncServer.loggingNotification(loggingMessageNotification).block(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index f121db552..52360e54b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -1,13 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + package io.modelcontextprotocol.server; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; /** * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The * exchange provides methods to interact with the client and query its capabilities. * * @author Dariusz Jędrzejczyk + * @author Christian Tzolov */ public class McpSyncServerExchange { @@ -75,4 +81,13 @@ public McpSchema.ListRootsResult listRoots(String cursor) { return this.exchange.listRoots(cursor).block(); } + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + */ + public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { + this.exchange.loggingNotification(loggingMessageNotification).block(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 4c596b628..e621ac19b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1165,6 +1165,11 @@ public int level() { } // @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { + } + // --------------------------- // Autocomplete // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index ac7b9e5ec..72b409af9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -454,15 +455,10 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { withClient(createMcpTransport(), mcpAsyncClient -> { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); - - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 4b4fc434f..c7c69b52b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -415,53 +415,4 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 17feb36e5..8c9328cc7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -387,53 +387,4 @@ void testRootsChangeHandlers() { assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 1254e2ad8..212a3c95d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -8,7 +8,6 @@ import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.spec.McpSchema; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; @@ -78,9 +77,13 @@ public void after() { @Test void testCustomContextPath() { - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - assertThat(client.initialize()).isNotNull(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (//@formatter:off + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) .build()) { //@formatter:on + + assertThat(client.initialize()).isNotNull(); + } + server.close(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index e34baf9d6..a7b634824 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server.transport; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -44,7 +45,7 @@ public class HttpServletSseServerTransportProviderIntegrationTests { - private static final int PORT = 8185; + private static final int PORT = 8189; private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -110,27 +111,29 @@ void testCreateMessageWithoutSamplingCapabilities() { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { - assertThat(client.initialize()).isNotNull(); + assertThat(client.initialize()).isNotNull(); - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @Test void testCreateMessageSuccess() throws InterruptedException { - // Client - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -139,13 +142,6 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -180,15 +176,19 @@ void testCreateMessageSuccess() throws InterruptedException { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } mcpServer.close(); } @@ -200,42 +200,43 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); - mcpClient.close(); - mcpServer.close(); + mcpServer.close(); + } } @Test @@ -252,21 +253,19 @@ void testRootsWithoutCapability() { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -278,20 +277,20 @@ void testRootsNotifciationWithEmptyRootsList() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -307,20 +306,20 @@ void testRootsWithMultipleHandlers() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -329,28 +328,26 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -386,19 +383,18 @@ void testToolCallSuccess() { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -418,13 +414,14 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -433,53 +430,167 @@ void testToolListChangeHandlingSuccess() { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @Test void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @Test + void testLoggingNotification() { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); - mcpClient.close(); + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), + (exchange, request) -> { + + // Create and send notifications with different levels + + // This should be filtered out (DEBUG < NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .block(); + + // This should be sent (NOTICE >= NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build()) + .block(); + + // This should be sent (ERROR > NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build()) + .block(); + + // This should be filtered out (INFO < NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build()) + .block(); + + // This should be sent (ERROR >= NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build()) + .block(); + + return Mono.just(new CallToolResult("Logging test completed", false)); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + System.out.println("Received notifications: " + receivedNotifications); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(3); + + // First notification should be NOTICE level + assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + }); + } mcpServer.close(); } From 63724f17a4a7d72f8f28b149a5940122b4a5bf02 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 16:54:24 +0200 Subject: [PATCH 024/205] refactor(tests): improve notification assertions in WebFluxSseIntegrationTests Replace index-based assertions with content-based lookups using a notification map. This change makes the tests more resilient by removing the dependency on notification order, which is important for asynchronous messaging tests. Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index d71fe1ab0..76f908b8a 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; @@ -596,20 +597,24 @@ void testLoggingNotification(String clientType) { // Should have received 3 notifications (1 NOTICE and 2 ERROR) assertThat(receivedNotifications).hasSize(3); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + // First notification should be NOTICE level - assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); // Second notification should be ERROR level - assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); // Third notification should be ERROR level - assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + assertThat(notificationMap.get("Another error message").level()) + .isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); }); } mcpServer.close(); From 2e953c81aa6d0e173801282fc03b01bfb413ff0f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 17:16:27 +0200 Subject: [PATCH 025/205] refactor(tests): improve notification assertions in HttpServletSseServerTransportProviderIntegrationTests Replace index-based assertions with content-based lookups using a notification map. This change makes the tests more resilient by removing the dependency on notification order, which is important for asynchronous messaging tests. Signed-off-by: Christian Tzolov --- ...rverTransportProviderIntegrationTests.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index a7b634824..135de83fa 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -9,6 +9,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; @@ -575,20 +576,24 @@ void testLoggingNotification() { // Should have received 3 notifications (1 NOTICE and 2 ERROR) assertThat(receivedNotifications).hasSize(3); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + // First notification should be NOTICE level - assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); // Second notification should be ERROR level - assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); // Third notification should be ERROR level - assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + assertThat(notificationMap.get("Another error message").level()) + .isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); }); } mcpServer.close(); From f348a83e5acef05b6c8807c7000c59098b667d28 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 18:06:46 +0200 Subject: [PATCH 026/205] Next development version Signed-off-by: Christian Tzolov --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 +++--- mcp-spring/mcp-spring-webmvc/pom.xml | 6 +++--- mcp-test/pom.xml | 4 ++-- mcp/pom.xml | 2 +- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 77d55da34..4f24f719f 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 186ade796..63c32a8a8 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 67e6b0aee..b59be6a03 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 95f5dc30a..f1484ae77 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index edb1c8f07..6b0f4a9fe 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp jar diff --git a/pom.xml b/pom.xml index 8e7cca2a9..ff485b75d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk From 8068854227cb378e44dcfe2985c5b002b65626e2 Mon Sep 17 00:00:00 2001 From: James Ward Date: Mon, 14 Apr 2025 23:41:41 -0600 Subject: [PATCH 027/205] add access to server instructions (#148) --- .../client/McpAsyncClient.java | 15 +++++++++++++++ .../client/McpSyncClient.java | 9 +++++++++ 2 files changed, 24 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index df099836d..1a9c39360 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -112,6 +112,11 @@ public class McpAsyncClient { */ private McpSchema.ServerCapabilities serverCapabilities; + /** + * Server instructions. + */ + private String serverInstructions; + /** * Server implementation information. */ @@ -240,6 +245,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() { return this.serverCapabilities; } + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The server instructions + */ + public String getServerInstructions() { + return this.serverInstructions; + } + /** * Get the server implementation information. * @return The server implementation details @@ -328,6 +342,7 @@ public Mono initialize() { return result.flatMap(initializeResult -> { this.serverCapabilities = initializeResult.capabilities(); + this.serverInstructions = initializeResult.instructions(); this.serverInfo = initializeResult.serverInfo(); logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 32cf325e9..8544c3637 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -79,6 +79,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() { return this.delegate.getServerCapabilities(); } + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The instructions + */ + public String getServerInstructions() { + return this.delegate.getServerInstructions(); + } + /** * Get the server implementation information. * @return The server implementation details From 263e3741b285f92baf87ff166e8d6dc3eafe9124 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Wed, 9 Apr 2025 19:43:29 +0200 Subject: [PATCH 028/205] feat(test): Use dynamic port allocation in integration tests (#133) - Add TestUtil class with findAvailablePort() method to the mcp-test module - Add findAvailablePort() method to TomcatTestUtil classes - Replace hardcoded port numbers with dynamic port allocation Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 7 +++-- .../server/WebFluxSseMcpAsyncServerTests.java | 2 +- .../server/WebFluxSseMcpSyncServerTests.java | 2 +- .../server/TomcatTestUtil.java | 10 +++++- .../WebMvcSseAsyncServerTransportTests.java | 3 +- .../WebMvcSseCustomContextPathTests.java | 8 ++--- .../server/WebMvcSseIntegrationTests.java | 6 ++-- .../WebMvcSseSyncServerTransportTests.java | 3 +- .../modelcontextprotocol/server/TestUtil.java | 31 +++++++++++++++++++ ...ervletSseServerCustomContextPathTests.java | 7 +++-- ...rverTransportProviderIntegrationTests.java | 9 +++--- .../server/transport/TomcatTestUtil.java | 28 ++++++++++++++--- 12 files changed, 87 insertions(+), 29 deletions(-) create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 76f908b8a..214b97f1b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -50,9 +51,9 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class WebFluxSseIntegrationTests { +class WebFluxSseIntegrationTests { - private static final int PORT = 8182; + private static final int PORT = TestUtil.findAvailablePort(); private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -133,7 +134,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) throws InterruptedException { + void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 98844c741..cc33e7b94 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -23,7 +23,7 @@ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 71072855e..2fc104538 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -23,7 +23,7 @@ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - private static final int PORT = 8182; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java index fcd7fb4dc..ccf9e2d77 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -3,6 +3,10 @@ */ package io.modelcontextprotocol.server; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; @@ -14,10 +18,14 @@ */ public class TomcatTestUtil { + TomcatTestUtil() { + // Prevent instantiation + } + public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { } - public TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { + public static TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { // Set up Tomcat first var tomcat = new Tomcat(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 08d5de671..6a6ad17e9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -25,7 +25,7 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; @@ -73,7 +73,6 @@ protected McpServerTransportProvider createMcpTransportProvider() { // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 0e81104b9..1b5218cc5 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -22,11 +22,11 @@ import static org.assertj.core.api.Assertions.assertThat; -public class WebMvcSseCustomContextPathTests { +class WebMvcSseCustomContextPathTests { private static final String CUSTOM_CONTEXT_PATH = "/app/1"; - private static final int PORT = 8183; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -39,11 +39,11 @@ public class WebMvcSseCustomContextPathTests { @BeforeEach public void before() { - tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); try { tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index be01365a1..df527f87e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -46,7 +46,7 @@ class WebMvcSseIntegrationTests { - private static final int PORT = 8183; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -75,7 +75,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro @BeforeEach public void before() { - tomcatServer = new TomcatTestUtil().createTomcatServer("", PORT, TestConfig.class); + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); try { tomcatServer.tomcat().start(); @@ -151,7 +151,7 @@ void testCreateMessageWithoutSamplingCapabilities() { } @Test - void testCreateMessageSuccess() throws InterruptedException { + void testCreateMessageSuccess() { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index b85bed379..1964703c1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -24,7 +24,7 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; @@ -72,7 +72,6 @@ protected WebMvcSseServerTransportProvider createMcpTransportProvider() { // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java new file mode 100644 index 000000000..0085f31ed --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java @@ -0,0 +1,31 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +public class TestUtil { + + TestUtil() { + // Prevent instantiation + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 212a3c95d..2cd62889a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server.transport; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -17,9 +18,9 @@ import static org.assertj.core.api.Assertions.assertThat; -public class HttpServletSseServerCustomContextPathTests { +class HttpServletSseServerCustomContextPathTests { - private static final int PORT = 8195; + private static final int PORT = TomcatTestUtil.findAvailablePort(); private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; @@ -48,7 +49,7 @@ public void before() { try { tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 135de83fa..f25ce5678 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -12,6 +12,7 @@ import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -44,9 +45,9 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class HttpServletSseServerTransportProviderIntegrationTests { +class HttpServletSseServerTransportProviderIntegrationTests { - private static final int PORT = 8189; + private static final int PORT = TomcatTestUtil.findAvailablePort(); private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -70,7 +71,7 @@ public void before() { tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); @@ -133,7 +134,7 @@ void testCreateMessageWithoutSamplingCapabilities() { } @Test - void testCreateMessageSuccess() throws InterruptedException { + void testCreateMessageSuccess() { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index 6f922dfa6..f61cdc413 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -3,19 +3,23 @@ */ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + import jakarta.servlet.Servlet; import org.apache.catalina.Context; -import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; -import static org.junit.Assert.assertThat; - /** * @author Christian Tzolov */ public class TomcatTestUtil { + TomcatTestUtil() { + // Prevent instantiation + } + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { var tomcat = new Tomcat(); @@ -24,7 +28,6 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se String baseDir = System.getProperty("java.io.tmpdir"); tomcat.setBaseDir(baseDir); - // Context context = tomcat.addContext("", baseDir); Context context = tomcat.addContext(contextPath, baseDir); // Add transport servlet to Tomcat @@ -42,4 +45,19 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se return tomcat; } + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + } From 472f07ec6da7a2233d0d73b2be8ff7b6f9a85233 Mon Sep 17 00:00:00 2001 From: mipengcheng3 Date: Tue, 8 Apr 2025 17:53:59 +0800 Subject: [PATCH 029/205] feat(mcp): add configurable request timeout to MCP server (#134) Adds the ability to configure request timeouts for MCP server operations. This enhancement allows setting a custom duration to wait for server responses before timing out requests, which applies to all requests made through the client including tool calls, resource access, and prompt operations. - Add requestTimeout parameter to McpServerSession constructor - Add requestTimeout field and builder method to server classes - Pass timeout configuration through to session creation - Add tests for both success and failure scenarios across different transport implementations - Default timeout is set to 10 seconds if not explicitly configured. --- .../WebFluxSseIntegrationTests.java | 149 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 145 +++++++++++++++++ .../server/McpAsyncServer.java | 13 +- .../server/McpServer.java | 39 ++++- .../spec/McpServerSession.java | 12 +- ...rverTransportProviderIntegrationTests.java | 145 +++++++++++++++++ 6 files changed, 491 insertions(+), 12 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 214b97f1b..dab54376e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -48,6 +49,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -196,6 +198,153 @@ void testCreateMessageSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(3); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index df527f87e..07b36c25b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -41,6 +42,7 @@ import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -212,6 +214,149 @@ void testCreateMessageSuccess() { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 062de13ed..4f7d0e87e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -90,8 +91,8 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); + McpServerFeatures.Async features, Duration requestTimeout) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); } /** @@ -271,7 +272,7 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -330,9 +331,9 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider - .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d5427335d..60434a841 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -193,11 +194,28 @@ class AsyncSpecification { private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private AsyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public AsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -565,7 +583,7 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); } } @@ -619,11 +637,28 @@ class SyncSpecification { private final List>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private SyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public SyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -992,7 +1027,7 @@ public McpSyncServer build() { this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46014af8d..46c356cdd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -27,6 +27,9 @@ public class McpServerSession implements McpSession { private final String id; + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + private final AtomicLong requestCounter = new AtomicLong(0); private final InitRequestHandler initRequestHandler; @@ -65,10 +68,11 @@ public class McpServerSession implements McpSession { * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ - public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, - InitNotificationHandler initNotificationHandler, Map> requestHandlers, - Map notificationHandlers) { + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { this.id = id; + this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; this.initNotificationHandler = initNotificationHandler; @@ -116,7 +120,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index f25ce5678..b8f040c7a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -42,6 +43,7 @@ import org.springframework.web.client.RestClient; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -194,6 +196,149 @@ void testCreateMessageSuccess() { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- From f7f8ccd0acb6d39558b65ceb1ae4e5f71619a37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 14 Apr 2025 12:08:30 +0200 Subject: [PATCH 030/205] Fix flaky test running blocking code in event loop (#155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace StepVerifier with assertWith for cleaner test assertions - Add try-with-resources blocks for proper client resource management - Use closeGracefully().block() for proper server shutdown Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 135 ++++++++---------- 1 file changed, 62 insertions(+), 73 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index dab54376e..6ba0911ed 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -37,10 +37,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; @@ -50,6 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -109,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + .thenReturn(mock(CallToolResult.class))); var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); @@ -151,6 +147,8 @@ void testCreateMessageSuccess(String clientType) { CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + AtomicReference samplingResult = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -165,16 +163,9 @@ void testCreateMessageSuccess(String clientType) { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -194,8 +185,17 @@ void testCreateMessageSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); } - mcpServer.close(); + mcpServer.closeGracefully().block(); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -218,16 +218,13 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + AtomicReference samplingResult = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -242,16 +239,9 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -260,16 +250,30 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); - mcpServer.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + + mcpServer.closeGracefully().block(); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -283,7 +287,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); try { - TimeUnit.SECONDS.sleep(3); + TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { throw new RuntimeException(e); @@ -292,11 +296,6 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), @@ -308,24 +307,9 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -334,15 +318,21 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.close(); - mcpServer.close(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + } + + mcpServer.closeGracefully().block(); } // --------------------------------------- @@ -412,9 +402,8 @@ void testRootsWithoutCapability(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - try ( - // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + // Create client without roots capability + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { assertThat(mcpClient.initialize()).isNotNull(); From 84adde16a539e2e482bf4deb2c6c3b17b3c148fd Mon Sep 17 00:00:00 2001 From: mackey0225 Date: Tue, 15 Apr 2025 10:48:19 +0900 Subject: [PATCH 031/205] fix: correct typos in variable names, method names, and commentsj (#159) --- .../WebFluxSseServerTransportProvider.java | 2 +- .../WebFluxSseIntegrationTests.java | 6 +++--- .../server/WebMvcSseIntegrationTests.java | 6 +++--- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 16 ++++++++-------- .../modelcontextprotocol/client/McpClient.java | 4 ++-- .../server/McpAsyncServer.java | 2 +- .../io/modelcontextprotocol/spec/McpSchema.java | 2 +- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 16 ++++++++-------- ...eServerTransportProviderIntegrationTests.java | 6 +++--- 11 files changed, 32 insertions(+), 32 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index eed8a53af..62264d9aa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -141,7 +141,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. - * @param baseUrl webflux messag base path + * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 6ba0911ed..80a126441 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -152,7 +152,7 @@ void testCreateMessageSuccess(String clientType) { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -163,7 +163,7 @@ void testCreateMessageSuccess(String clientType) { .build()) .build(); - return exchange.createMessage(craeteMessageRequest) + return exchange.createMessage(createMessageRequest) .doOnNext(samplingResult::set) .thenReturn(callResponse); }); @@ -421,7 +421,7 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotifciationWithEmptyRootsList(String clientType) { + void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 07b36c25b..b12d68439 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -169,7 +169,7 @@ void testCreateMessageSuccess() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -180,7 +180,7 @@ void testCreateMessageSuccess() { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(Role.USER); assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); @@ -438,7 +438,7 @@ void testRootsWithoutCapability() { } @Test - void testRootsNotifciationWithEmptyRootsList() { + void testRootsNotificationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index a91632c6c..cdd43e7ef 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -110,7 +110,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 9a63143c9..c81e638c1 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -200,16 +200,16 @@ void testAddResource() { Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullSpecifiation() { + void testAddResourceWithNullSpecification() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) @@ -279,11 +279,11 @@ void testAddPromptWithoutCapability() { .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -300,14 +300,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specificaiton) + .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -340,7 +340,7 @@ void testRootsChangeHandlers() { var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchage, roots) -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index f7b179616..a1dc11685 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -196,7 +196,7 @@ public SyncSpec requestTimeout(Duration requestTimeout) { } /** - * @param initializationTimeout The duration to wait for the initializaiton + * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. * @return This builder instance for method chaining * @throws IllegalArgumentException if initializationTimeout is null @@ -435,7 +435,7 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { } /** - * @param initializationTimeout The duration to wait for the initializaiton + * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. * @return This builder instance for method chaining * @throws IllegalArgumentException if initializationTimeout is null diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 4f7d0e87e..28b63cecb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -359,7 +359,7 @@ private Mono asyncInitializeRequestHandler( } else { logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", initializeRequest.protocolVersion(), serverProtocolVersion); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e621ac19b..6eb5159f6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1103,7 +1103,7 @@ public record ProgressNotification(// @formatter:off * setting minimum log levels, with servers sending notifications containing severity * levels, optional logger names, and arbitrary JSON-serializable data. * - * @param level The severity levels. The mimimum log level is set by the client. + * @param level The severity levels. The minimum log level is set by the client. * @param logger The logger that generated the message. * @param data JSON-serializable logging data. */ diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index c7c69b52b..df0b0c729 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -109,7 +109,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 8c9328cc7..0b38da857 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -199,16 +199,16 @@ void testAddResource() { Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullSpecifiation() { + void testAddResourceWithNullSpecification() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) @@ -278,11 +278,11 @@ void testAddPromptWithoutCapability() { .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -299,14 +299,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specificaiton) + .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -339,7 +339,7 @@ void testRootsChangeHandlers() { var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchage, roots) -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index b8f040c7a..2ff6325a4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -152,7 +152,7 @@ void testCreateMessageSuccess() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -163,7 +163,7 @@ void testCreateMessageSuccess() { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(Role.USER); assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); @@ -417,7 +417,7 @@ void testRootsWithoutCapability() { } @Test - void testRootsNotifciationWithEmptyRootsList() { + void testRootsNotificationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) From 7853efefdabeec61f093b4c048dda0eb95263348 Mon Sep 17 00:00:00 2001 From: jitokim Date: Fri, 11 Apr 2025 13:36:37 +0900 Subject: [PATCH 032/205] feat(completion): implement completion API support (#141) Add completion API support to the MCP protocol implementation: - Add CompleteRequest and CompleteResult schema classes - Implement completion capabilities in ServerCapabilities - Add completion handlers in McpAsyncServer and McpServer - Add completion client methods in McpAsyncClient and McpSyncClient - Add CompletionRefKey and completion specifications in McpServerFeatures - Add integration test for completion functionality - Fix isPresent() check to use isEmpty() in WebMvcSseServerTransportProvider - Replace McpServerFeatures.CompletionRefKey by McpSchemaCompleteReference Co-authored-by: Christian Tzolov Signed-off-by: jitokim --- .../WebFluxSseIntegrationTests.java | 65 +++++++++++--- .../WebMvcSseServerTransportProvider.java | 2 +- .../client/McpAsyncClient.java | 22 +++++ .../client/McpSyncClient.java | 11 +++ .../server/McpAsyncServer.java | 84 +++++++++++++++++++ .../server/McpServer.java | 43 +++++++++- .../server/McpServerFeatures.java | 71 +++++++++++++++- .../spec/McpClientSession.java | 1 + .../modelcontextprotocol/spec/McpSchema.java | 74 ++++++++++++---- 9 files changed, 340 insertions(+), 33 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 80a126441..08619bd31 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; @@ -20,19 +21,12 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.*; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -759,4 +753,53 @@ void testLoggingNotification(String clientType) { mcpServer.close(); } -} + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "this is code review prompt", List.of()), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "code_review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 7bd1aa6c9..fc86cfaa0 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -300,7 +300,7 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (!request.param("sessionId").isPresent()) { + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 1a9c39360..2bc74f258 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -71,6 +71,7 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim * @see McpClient * @see McpSchema * @see McpClientSession @@ -816,4 +817,25 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + // -------------------------- + // Completions + // -------------------------- + private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Sends a completion/complete request to generate value suggestions based on a given + * reference and argument. This is typically used to provide auto-completion options + * for user input fields. + * @param completeRequest The request containing the prompt or resource reference and + * argument for which to generate completions. + * @return A Mono that completes with the result containing completion suggestions. + * @see McpSchema.CompleteRequest + * @see McpSchema.CompleteResult + */ + public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { + return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 8544c3637..c91638a7e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -46,6 +46,7 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim * @see McpClient * @see McpAsyncClient * @see McpSchema @@ -334,4 +335,14 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { this.delegate.setLoggingLevel(loggingLevel).block(); } + /** + * Send a completion/complete request. + * @param completeRequest the completion request contains the prompt or resource + * reference and arguments for generating suggestions. + * @return the completion result containing suggested values. + */ + public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { + return this.delegate.completeCompletion(completeRequest).block(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 28b63cecb..906cb9a08 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -69,6 +69,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jihoon Kim * @see McpServer * @see McpSchema * @see McpClientSession @@ -269,6 +270,8 @@ private static class AsyncServerImpl extends McpAsyncServer { // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, @@ -282,6 +285,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resources.putAll(features.resources()); this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); Map> requestHandlers = new HashMap<>(); @@ -314,6 +318,11 @@ private static class AsyncServerImpl extends McpAsyncServer { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + Map notificationHandlers = new HashMap<>(); notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); @@ -706,6 +715,81 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { }; } + private McpServerSession.RequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } + + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } + + String type = request.ref().type(); + + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification prompt = this.prompts.get(promptReference.name()); + if (prompt == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + } + } + + if (type.equals("ref/resource") + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resource = this.resources.get(resourceReference.uri()); + if (resource == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + } + + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return specification.completionHandler().apply(exchange, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a + * {@link McpSchema.CompleteRequest} object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input + * map, determines the correct reference type (either prompt or resource), and + * constructs a fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" + * and "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured + * completion request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument( + argName, argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + // --------------------------------------- // Sampling // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 60434a841..84089703c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -115,6 +115,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jihoon Kim * @see McpAsyncServer * @see McpSyncServer * @see McpServerTransportProvider @@ -192,6 +193,8 @@ class AsyncSpecification { */ private final Map prompts = new HashMap<>(); + private final Map completions = new HashMap<>(); + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout @@ -581,7 +584,8 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); } @@ -635,6 +639,8 @@ class SyncSpecification { */ private final Map prompts = new HashMap<>(); + private final Map completions = new HashMap<>(); + private final List>> rootsChangeHandlers = new ArrayList<>(); private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout @@ -957,6 +963,37 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr return this; } + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + * @see #completions(McpServerFeatures.SyncCompletionSpecification...) + */ + public SyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + /** * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new @@ -1023,8 +1060,8 @@ public SyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, - this.instructions); + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index e0f337b78..8311f5d41 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -21,6 +21,7 @@ * MCP server features specification that a particular server can choose to support. * * @author Dariusz Jędrzejczyk + * @author Jihoon Kim */ public class McpServerFeatures { @@ -41,6 +42,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, + Map completions, List, Mono>> rootsChangeConsumers, String instructions) { @@ -60,6 +62,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, + Map completions, List, Mono>> rootsChangeConsumers, String instructions) { @@ -67,7 +70,8 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.serverInfo = serverInfo; this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable // logging // by @@ -81,6 +85,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.resources = (resources != null) ? resources : Map.of(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; } @@ -109,6 +114,11 @@ static Async fromSync(Sync syncSpec) { prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); }); + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion)); + }); + List, Mono>> rootChangeConsumers = new ArrayList<>(); for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { @@ -118,7 +128,7 @@ static Async fromSync(Sync syncSpec) { } return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, rootChangeConsumers, syncSpec.instructions()); + syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); } } @@ -140,6 +150,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, + Map completions, List>> rootsChangeConsumers, String instructions) { /** @@ -159,6 +170,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, + Map completions, List>> rootsChangeConsumers, String instructions) { @@ -166,7 +178,8 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.serverInfo = serverInfo; this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable // logging // by @@ -180,6 +193,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.resources = (resources != null) ? resources : new HashMap<>(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); this.instructions = instructions; } @@ -325,6 +339,44 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { } } + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *

    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpAsyncServerExchange} used to interact with the client. The second + * argument is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), + (exchange, request) -> Mono.fromCallable( + () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + /** * Specification of a tool with its synchronous handler function. Tools are the * primary way for MCP servers to expose functionality to AI models. Each tool @@ -431,4 +483,17 @@ public record SyncPromptSpecification(McpSchema.Prompt prompt, BiFunction promptHandler) { } + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpSyncServerExchange} used to interact with the client. The second argument + * is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0895e02b0..c1f42e3fb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -238,6 +238,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc }); }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { + logger.error("Error handling request: {}", jsonRpcResponse.error()); sink.error(new McpError(jsonRpcResponse.error())); } else { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 6eb5159f6..55fdc1724 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -79,6 +79,8 @@ private McpSchema() { public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; + public static final String METHOD_COMPLETION_COMPLETE = "completion/complete"; + // Logging Methods public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; @@ -314,12 +316,16 @@ public ClientCapabilities build() { @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ServerCapabilities( // @formatter:off + @JsonProperty("completions") CompletionCapabilities completions, @JsonProperty("experimental") Map experimental, @JsonProperty("logging") LoggingCapabilities logging, @JsonProperty("prompts") PromptCapabilities prompts, @JsonProperty("resources") ResourceCapabilities resources, @JsonProperty("tools") ToolCapabilities tools) { + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record CompletionCapabilities() { + } @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { @@ -347,12 +353,18 @@ public static Builder builder() { public static class Builder { + private CompletionCapabilities completions; private Map experimental; private LoggingCapabilities logging = new LoggingCapabilities(); private PromptCapabilities prompts; private ResourceCapabilities resources; private ToolCapabilities tools; + public Builder completions(CompletionCapabilities completions) { + this.completions = completions; + return this; + } + public Builder experimental(Map experimental) { this.experimental = experimental; return this; @@ -379,7 +391,7 @@ public Builder tools(Boolean listChanged) { } public ServerCapabilities build() { - return new ServerCapabilities(experimental, logging, prompts, resources, tools); + return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); } } } // @formatter:on @@ -1173,31 +1185,63 @@ public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { // --------------------------- // Autocomplete // --------------------------- - public record CompleteRequest(PromptOrResourceReference ref, CompleteArgument argument) implements Request { - public sealed interface PromptOrResourceReference permits PromptReference, ResourceReference { + public sealed interface CompleteReference permits PromptReference, ResourceReference { + + String type(); + + String identifier(); - String type(); + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptReference(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("name") String name) implements McpSchema.CompleteReference { + public PromptReference(String name) { + this("ref/prompt", name); } - public record PromptReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("name") String name) implements PromptOrResourceReference { - }// @formatter:on + @Override + public String identifier() { + return name(); + } + }// @formatter:on - public record ResourceReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("uri") String uri) implements PromptOrResourceReference { - }// @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceReference(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { - public record CompleteArgument(// @formatter:off + public ResourceReference(String uri) { + this("ref/resource", uri); + } + + @Override + public String identifier() { + return uri(); + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteRequest(// @formatter:off + @JsonProperty("ref") McpSchema.CompleteReference ref, + @JsonProperty("argument") CompleteArgument argument) implements Request { + + public record CompleteArgument( @JsonProperty("name") String name, @JsonProperty("value") String value) { }// @formatter:on } - public record CompleteResult(CompleteCompletion completion) { - public record CompleteCompletion(// @formatter:off + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteResult(@JsonProperty("values") CompleteCompletion completion) { // @formatter:off + + public record CompleteCompletion( @JsonProperty("values") List values, @JsonProperty("total") Integer total, @JsonProperty("hasMore") Boolean hasMore) { From 734d1732d6d3e74cd427ac5a4dba95a02ba08618 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 17 Apr 2025 09:28:08 +0200 Subject: [PATCH 033/205] Refactor: Simplify ServerCapabilities builder API for completions Signed-off-by: Christian Tzolov --- .../io/modelcontextprotocol/WebFluxSseIntegrationTests.java | 2 +- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 08619bd31..660f814da 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -774,7 +774,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { }; var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build()) + .capabilities(ServerCapabilities.builder().completions().build()) .prompts(new McpServerFeatures.SyncPromptSpecification( new Prompt("code_review", "this is code review prompt", List.of()), (mcpSyncServerExchange, getPromptRequest) -> null)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 55fdc1724..e7e338030 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -360,8 +360,8 @@ public static class Builder { private ResourceCapabilities resources; private ToolCapabilities tools; - public Builder completions(CompletionCapabilities completions) { - this.completions = completions; + public Builder completions() { + this.completions = new CompletionCapabilities(); return this; } From 344f1b838340efbcf0cfcd428358bdb77e9a533c Mon Sep 17 00:00:00 2001 From: E550448 Date: Sun, 13 Apr 2025 12:54:48 +0200 Subject: [PATCH 034/205] feat(mcp): resolve absolute and relative message endpoint URIs (#150) Improve endpoint URI handling by supporting both relative paths and properly validated absolute URIs. - Implement URI resolution in HttpClientSseClientTransport: - Change baseUri field from String to URI type - Add Utils.resolveUri method to handle both absolute and relative URIs - Resolve relative URIs against the base URI - Validate absolute URIs to ensure they match base URI's scheme, authority, and path - Add parameterized tests for various URI resolution scenarios - Add ByteBuddy dependency for HttpClient mocking and update Mockito Signed-off-by: Christian Tzolov --- README.md | 2 +- mcp/pom.xml | 14 +++++ .../HttpClientSseClientTransport.java | 11 ++-- .../io/modelcontextprotocol/util/Utils.java | 56 ++++++++++++++++++- .../HttpClientSseClientTransportTests.java | 29 +++++++++- .../modelcontextprotocol/util/UtilsTests.java | 29 ++++++++++ pom.xml | 5 +- 7 files changed, 136 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ca87736cd..9fc17306e 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. -## 📚 Reference Documentation +## 📚 Reference Documentation #### MCP Java SDK documentation For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview). diff --git a/mcp/pom.xml b/mcp/pom.xml index 6b0f4a9fe..17693ab32 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -126,12 +126,26 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core ${mockito.version} test + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 632d3844a..99cf2a625 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -69,7 +70,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** Base URI for the MCP server */ - private final String baseUri; + private final URI baseUri; /** SSE endpoint path */ private final String sseEndpoint; @@ -178,7 +179,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.baseUri = baseUri; + this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; @@ -340,7 +341,8 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() { + URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { @@ -412,7 +414,8 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) + URI requestUri = Utils.resolveUri(baseUri, endpoint); + HttpRequest request = this.requestBuilder.uri(requestUri) .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 0f799ca0f..8e654e596 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,11 +4,12 @@ package io.modelcontextprotocol.util; +import reactor.util.annotation.Nullable; + +import java.net.URI; import java.util.Collection; import java.util.Map; -import reactor.util.annotation.Nullable; - /** * Miscellaneous utility methods. * @@ -52,4 +53,55 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + /** + * Resolves the given endpoint URL against the base URL. + *
    + *
  • If the endpoint URL is relative, it will be resolved against the base URL.
  • + *
  • If the endpoint URL is absolute, it will be validated to ensure it matches the + * base URL's scheme, authority, and path prefix.
  • + *
  • If validation fails for an absolute URL, an {@link IllegalArgumentException} is + * thrown.
  • + *
+ * @param baseUrl The base URL (must be absolute) + * @param endpointUrl The endpoint URL (can be relative or absolute) + * @return The resolved endpoint URI + * @throws IllegalArgumentException If the absolute endpoint URL does not match the + * base URL or URI is malformed + */ + public static URI resolveUri(URI baseUrl, String endpointUrl) { + URI endpointUri = URI.create(endpointUrl); + if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + } + else { + return baseUrl.resolve(endpointUri); + } + } + + /** + * Checks if the given absolute endpoint URI falls under the base URI. It validates + * the scheme, authority (host and port), and ensures that the base path is a prefix + * of the endpoint path. + * @param baseUri The base URI + * @param endpointUri The endpoint URI to check + * @return true if endpointUri is within baseUri's hierarchy, false otherwise + */ + private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) { + if (!baseUri.getScheme().equals(endpointUri.getScheme()) + || !baseUri.getAuthority().equals(endpointUri.getAuthority())) { + return false; + } + + URI normalizedBase = baseUri.normalize(); + URI normalizedEndpoint = endpointUri.normalize(); + + String basePath = normalizedBase.getPath(); + String endpointPath = normalizedEndpoint.getPath(); + + if (basePath.endsWith("/")) { + basePath = basePath.substring(0, basePath.length() - 1); + } + return endpointPath.startsWith(basePath); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e5178c0ee..762264de3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,12 +7,13 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -21,6 +22,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -31,6 +34,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -364,4 +370,25 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } + @Test + @SuppressWarnings("unchecked") + void testResolvingClientEndpoint() { + HttpClient httpClient = Mockito.mock(HttpClient.class); + HttpResponse httpResponse = Mockito.mock(HttpResponse.class); + CompletableFuture> future = new CompletableFuture<>(); + future.complete(httpResponse); + when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); + + HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), + "http://example.com", "http://example.com/sse", new ObjectMapper()); + + transport.connect(Function.identity()); + + ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); + assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); + + transport.closeGracefully().block(); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java index aced20cbc..0f2e689b5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -6,12 +6,17 @@ import org.junit.jupiter.api.Test; +import java.net.URI; import java.util.Collection; import java.util.List; import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; class UtilsTests { @@ -37,4 +42,28 @@ void testMapIsEmpty() { assertFalse(Utils.isEmpty(Map.of("key", "value"))); } + @ParameterizedTest + @CsvSource({ + // relative endpoints + "http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1", + "http://localhost:8080/root/, api, http://localhost:8080/root/api", + "http://localhost:8080, /api, http://localhost:8080/api", + // absolute endpoints matching base + "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1", + "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" }) + void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) { + URI result = Utils.resolveUri(URI.create(baseUrl), endpoint); + assertThat(result.toString()).isEqualTo(expectedResult); + } + + @ParameterizedTest + @CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api", + "http://localhost:8080/root, http://otherhost/api", + "http://localhost:8080/root, http://localhost:9090/root/api" }) + void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) { + assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not match the base URL"); + } + } \ No newline at end of file diff --git a/pom.xml b/pom.xml index ff485b75d..9be256ccf 100644 --- a/pom.xml +++ b/pom.xml @@ -60,8 +60,9 @@ 3.26.3 5.10.2 - 5.11.0 + 5.17.0 1.20.4 + 1.17.5 2.0.16 1.5.15 @@ -356,4 +357,4 @@ - \ No newline at end of file + From e4091f458a28e31f87a517f411fe9d18811027a6 Mon Sep 17 00:00:00 2001 From: "jie.bao" Date: Fri, 18 Apr 2025 09:34:55 +0800 Subject: [PATCH 035/205] feat(completion): fix the schema about CompleteResult /** * The server's response to a completion/complete request */ export interface CompleteResult extends Result { completion: { /** * An array of completion values. Must not exceed 100 items. */ values: string[]; /** * The total number of completion options available. This can exceed the number of values actually sent in the response. */ total?: number; /** * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. */ hasMore?: boolean; }; } --- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e7e338030..e77edb3b7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1239,7 +1239,7 @@ public record CompleteArgument( @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteResult(@JsonProperty("values") CompleteCompletion completion) { // @formatter:off + public record CompleteResult(@JsonProperty("completion") CompleteCompletion completion) { // @formatter:off public record CompleteCompletion( @JsonProperty("values") List values, From 41c6bd9af09462a87064dc035d5e123d7f1eae58 Mon Sep 17 00:00:00 2001 From: JermaineHua Date: Thu, 17 Apr 2025 22:25:00 +0800 Subject: [PATCH 036/205] Fix method not found error msg for server Signed-off-by: JermaineHua --- .../io/modelcontextprotocol/spec/McpClientSession.java | 2 +- .../io/modelcontextprotocol/spec/McpServerSession.java | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index c1f42e3fb..9ed0d8edd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -178,7 +178,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR record MethodNotFoundError(String method, String message, Object data) { } - public static MethodNotFoundError getMethodNotFoundError(String method) { + private MethodNotFoundError getMethodNotFoundError(String method) { switch (method) { case McpSchema.METHOD_ROOTS_LIST: return new MethodNotFoundError(method, "Roots not supported", diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46c356cdd..64315095b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -257,14 +257,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti record MethodNotFoundError(String method, String message, Object data) { } - static MethodNotFoundError getMethodNotFoundError(String method) { - switch (method) { - case McpSchema.METHOD_ROOTS_LIST: - return new MethodNotFoundError(method, "Roots not supported", - Map.of("reason", "Client does not have roots capability")); - default: - return new MethodNotFoundError(method, "Method not found: " + method, null); - } + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); } @Override From 04046ca05b6b90f9a6ec2f40236c69470b878fe6 Mon Sep 17 00:00:00 2001 From: JermaineHua Date: Wed, 16 Apr 2025 23:10:59 +0800 Subject: [PATCH 037/205] Optimize client nested streams in McpClientSession (#33) Signed-off-by: JermaineHua --- .../spec/McpClientSession.java | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 9ed0d8edd..a25f38c5c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -122,7 +122,12 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> handle(message).subscribe())) + .subscribe(); + } + + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { if (message instanceof McpSchema.JSONRPCResponse response) { logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); @@ -132,23 +137,27 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, else { sink.success(response); } + return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), - error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - transport.sendMessage(errorResponse).subscribe(); - }); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification).subscribe(null, - error -> logger.error("Error handling notification: {}", error.getMessage())); + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } - })).subscribe(); + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); } /** From 866732c3833e863ea145c6e1dfa32b9d089211e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 23 Apr 2025 11:06:26 +0200 Subject: [PATCH 038/205] Polish #33 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../spec/McpClientSession.java | 58 +++++++++---------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index a25f38c5c..6eca34757 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -122,42 +122,38 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> handle(message).subscribe())) - .subscribe(); + this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); } - public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); - } - else { - sink.success(response); - } - return Mono.empty(); - } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.debug("Received notification: {}", notification); - return handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + private void handle(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); } else { - logger.warn("Received unknown message type: {}", message); - return Mono.empty(); + sink.success(response); } - }); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage).subscribe(); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) + .subscribe(); + } + else { + logger.warn("Received unknown message type: {}", message); + } } /** From 86e3e9048f53b706849a2a58a11aae70c3a1f391 Mon Sep 17 00:00:00 2001 From: jito Date: Wed, 23 Apr 2025 23:27:10 +0900 Subject: [PATCH 039/205] Fix typo in WebFluxSseIntegrationTests (#142) Signed-off-by: jitokim From f70b98b4b4160ea590a0c845ee3e2a7357bdcae9 Mon Sep 17 00:00:00 2001 From: Richie Caputo <43445060+arcaputo3@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:47:48 -0400 Subject: [PATCH 040/205] feat(schema): add support for JSON Schema $defs and definitions (#146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added support for $defs and definitions properties in JsonSchema record to handle JSON Schema references properly. Added tests to verify both formats work correctly. The JsonSchema test approach uses serialization/deserialization round-trip validation instead of property-by-property assertions. This makes tests more maintainable and less likely to break when new properties are added. 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude --- .../modelcontextprotocol/spec/McpSchema.java | 4 +- .../spec/McpSchemaTests.java | 129 ++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e77edb3b7..8df8a1584 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -703,7 +703,9 @@ public record JsonSchema( // @formatter:off @JsonProperty("type") String type, @JsonProperty("properties") Map properties, @JsonProperty("required") List required, - @JsonProperty("additionalProperties") Boolean additionalProperties) { + @JsonProperty("additionalProperties") Boolean additionalProperties, + @JsonProperty("$defs") Map defs, + @JsonProperty("definitions") Map definitions) { } // @formatter:on /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index a41fc095f..ff78c1bfc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -9,6 +9,7 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; @@ -449,6 +450,92 @@ void testGetPromptResult() throws Exception { // Tool Tests + @Test + void testJsonSchema() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/$defs/Address" + } + }, + "required": ["name"], + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = mapper.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = mapper.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testJsonSchemaWithDefinitions() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/definitions/Address" + } + }, + "required": ["name"], + "definitions": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = mapper.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = mapper.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + @Test void testTool() throws Exception { String schemaJson = """ @@ -477,6 +564,48 @@ void testTool() throws Exception { {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); } + @Test + void testToolWithComplexSchema() throws Exception { + String complexSchemaJson = """ + { + "type": "object", + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "properties": { + "name": {"type": "string"}, + "shippingAddress": {"$ref": "#/$defs/Address"} + }, + "required": ["name", "shippingAddress"] + } + """; + + McpSchema.Tool tool = new McpSchema.Tool("addressTool", "Handles addresses", complexSchemaJson); + + // Serialize the tool to a string + String serialized = mapper.writeValueAsString(tool); + + // Deserialize back to a Tool object + McpSchema.Tool deserializedTool = mapper.readValue(serialized, McpSchema.Tool.class); + + // Serialize again and compare with first serialization + String serializedAgain = mapper.writeValueAsString(deserializedTool); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + + // Just verify the basic structure was preserved + assertThat(deserializedTool.inputSchema().defs()).isNotNull(); + assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); + } + @Test void testCallToolRequest() throws Exception { Map arguments = new HashMap<>(); From 9c92a2b8bffe41f4c6df27ca1977bc8ee8343137 Mon Sep 17 00:00:00 2001 From: wangzhi <1277975348@qq.com> Date: Wed, 23 Apr 2025 23:03:10 +0800 Subject: [PATCH 041/205] Fix javadoc references and formatting (#149) --- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 2 +- .../java/io/modelcontextprotocol/client/McpAsyncClient.java | 4 ++-- .../java/io/modelcontextprotocol/spec/McpServerSession.java | 3 ++- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index cdd43e7ef..025cfeacf 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -30,7 +30,7 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. * * @author Christian Tzolov */ diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index c81e638c1..e313454bd 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -27,7 +27,7 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. * * @author Christian Tzolov */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 2bc74f258..e3a997ba3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -317,9 +317,9 @@ public Mono closeGracefully() { * The client MUST initiate this phase by sending an initialize request containing: * The protocol version the client supports, client's capabilities and clients * implementation information. - *

+ *

* The server MUST respond with its own capabilities and information. - *

+ *

* After successful initialization, the client MUST send an initialized notification * to indicate it is ready to begin normal operations. * @return the initialize result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 64315095b..86906d859 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -64,7 +64,8 @@ public class McpServerSession implements McpSession { * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the * server * @param initNotificationHandler called when a - * {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received. + * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is + * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ From 261554bb7f1cc630aefeb5487434c1740a72b856 Mon Sep 17 00:00:00 2001 From: Francis Hodianto <61911161+FH-30@users.noreply.github.com> Date: Wed, 23 Apr 2025 23:09:41 +0800 Subject: [PATCH 042/205] fix: propagate Reactor Context into client transport chain (#154) --- .../java/io/modelcontextprotocol/spec/McpClientSession.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 6eca34757..f577b493a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -230,18 +230,19 @@ private String generateRequestId() { public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); - return Mono.create(sink -> { + return Mono.deferContextual(ctx -> Mono.create(sink -> { this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); this.transport.sendMessage(jsonrpcRequest) + .contextWrite(ctx) // TODO: It's most efficient to create a dedicated Subscriber here .subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + })).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { logger.error("Error handling request: {}", jsonRpcResponse.error()); sink.error(new McpError(jsonRpcResponse.error())); From e610d853f922e36ba474b2240f5c6546166e4840 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 20 Apr 2025 10:58:51 +0300 Subject: [PATCH 043/205] feat: Add customizable URI template manager factory to MCP server Implement URI template functionality for MCP resources, allowing dynamic resource URIs with variables in the format {variableName}. - Enable resource URIs with variable placeholders (e.g., "/api/users/{userId}") - Automatic extraction of variable values from request URIs - Validation of template arguments in completions - Matching of request URIs against templates - Add new URI template management interfaces and implementations - Enhanced resource template listing to include templated resources - Updated resource request handling to support template matching - Test coverage for URI template functionality - Adding a configurable uriTemplateManagerFactory field to both AsyncSpecification and SyncSpecification classes - Adding builder methods to allow setting a custom URI template manager factory - Modifying constructors to pass the URI template manager factory to the server implementation - Updating the server implementation to use the provided factory - Add bulk registration methods for async completions Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 3 +- .../server/McpAsyncServer.java | 77 +++++++-- .../server/McpServer.java | 68 +++++++- .../DeafaultMcpUriTemplateManagerFactory.java | 23 +++ .../util/DefaultMcpUriTemplateManager.java | 163 ++++++++++++++++++ .../util/McpUriTemplateManager.java | 52 ++++++ .../util/McpUriTemplateManagerFactory.java | 22 +++ .../McpUriTemplateManagerTests.java | 97 +++++++++++ 8 files changed, 489 insertions(+), 16 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 660f814da..2ba047461 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -776,7 +776,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build()) .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "this is code review prompt", List.of()), + new Prompt("code_review", "this is code review prompt", + List.of(new PromptArgument("language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a08..3c112ad76 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.server; import java.time.Duration; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -22,10 +23,13 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,8 +96,10 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, + uriTemplateManagerFactory); } /** @@ -274,8 +280,11 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -286,6 +295,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; Map> requestHandlers = new HashMap<>(); @@ -564,8 +574,26 @@ private McpServerSession.RequestHandler resources private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; } private McpServerSession.RequestHandler resourcesReadRequestHandler() { @@ -574,11 +602,16 @@ private McpServerSession.RequestHandler resourcesR new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); - if (specification != null) { - return specification.readHandler().apply(exchange, resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); }; } @@ -729,20 +762,38 @@ private McpServerSession.RequestHandler completionComp String type = request.ref().type(); + String argumentName = request.argument().name(); + // check if the referenced resource exists if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification prompt = this.prompts.get(promptReference.name()); - if (prompt == null) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } } if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resource = this.resources.get(resourceReference.uri()); - if (resource == null) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); + if (resourceSpec == null) { return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 84089703c..d6ec2cc30 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import reactor.core.publisher.Mono; /** @@ -156,6 +158,8 @@ class AsyncSpecification { private final McpServerTransportProvider transportProvider; + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -204,6 +208,19 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -517,6 +534,36 @@ public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... return this; } + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + /** * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new @@ -587,7 +634,8 @@ public McpAsyncServer build() { this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory); } } @@ -600,6 +648,8 @@ class SyncSpecification { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private final McpServerTransportProvider transportProvider; private ObjectMapper objectMapper; @@ -650,6 +700,19 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -1064,7 +1127,8 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java new file mode 100644 index 000000000..3870b76fc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java @@ -0,0 +1,23 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * @author Christian Tzolov + */ +public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + @Override + public McpUriTemplateManager create(String uriTemplate) { + return new DefaultMcpUriTemplateManager(uriTemplate); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java new file mode 100644 index 000000000..b2e9a5285 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -0,0 +1,163 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Default implementation of the UriTemplateUtils interface. + *

+ * This class provides methods for extracting variables from URI templates and matching + * them against actual URIs. + * + * @author Christian Tzolov + */ +public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { + + /** + * Pattern to match URI variables in the format {variableName}. + */ + private static final Pattern URI_VARIABLE_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + + private final String uriTemplate; + + /** + * Constructor for DefaultMcpUriTemplateManager. + * @param uriTemplate The URI template to be used for variable extraction + */ + public DefaultMcpUriTemplateManager(String uriTemplate) { + if (uriTemplate == null || uriTemplate.isEmpty()) { + throw new IllegalArgumentException("URI template must not be null or empty"); + } + this.uriTemplate = uriTemplate; + } + + /** + * Extract URI variable names from a URI template. + * @param uriTemplate The URI template containing variables in the format + * {variableName} + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + @Override + public List getVariableNames() { + if (uriTemplate == null || uriTemplate.isEmpty()) { + return List.of(); + } + + List variables = new ArrayList<>(); + Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + + while (matcher.find()) { + String variableName = matcher.group(1); + if (variables.contains(variableName)) { + throw new IllegalArgumentException("Duplicate URI variable name in template: " + variableName); + } + variables.add(variableName); + } + + return variables; + } + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param requestUri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + @Override + public Map extractVariableValues(String requestUri) { + Map variableValues = new HashMap<>(); + List uriVariables = this.getVariableNames(); + + if (requestUri == null || uriVariables.isEmpty()) { + return variableValues; + } + + try { + // Create a regex pattern by replacing each {variableName} with a capturing + // group + StringBuilder patternBuilder = new StringBuilder("^"); + + // Find all variable placeholders and their positions + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Add the text between the last variable and this one, escaped for regex + String textBefore = uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + + // Add a capturing group for the variable + patternBuilder.append("([^/]+)"); + + lastEnd = variableMatcher.end(); + } + + // Add any remaining text after the last variable + if (lastEnd < uriTemplate.length()) { + patternBuilder.append(Pattern.quote(uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); + + // Compile the pattern and match against the request URI + Pattern pattern = Pattern.compile(patternBuilder.toString()); + Matcher matcher = pattern.matcher(requestUri); + + if (matcher.find() && matcher.groupCount() == uriVariables.size()) { + for (int i = 0; i < uriVariables.size(); i++) { + String value = matcher.group(i + 1); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "Empty value for URI variable '" + uriVariables.get(i) + "' in URI: " + requestUri); + } + variableValues.put(uriVariables.get(i), value); + } + } + } + catch (Exception e) { + throw new IllegalArgumentException("Error parsing URI template: " + uriTemplate + " for URI: " + requestUri, + e); + } + + return variableValues; + } + + /** + * Check if a URI matches the uriTemplate with variables. + * @param uri The URI to check + * @return true if the URI matches the pattern, false otherwise + */ + @Override + public boolean matches(String uri) { + // If the uriTemplate doesn't contain variables, do a direct comparison + if (!this.isUriTemplate(this.uriTemplate)) { + return uri.equals(this.uriTemplate); + } + + // Convert the pattern to a regex + String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); + regex = regex.replace("/", "\\/"); + + // Check if the URI matches the regex + return Pattern.compile(regex).matcher(uri).matches(); + } + + @Override + public boolean isUriTemplate(String uri) { + return URI_VARIABLE_PATTERN.matcher(uri).find(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java new file mode 100644 index 000000000..19569e49f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +/** + * Interface for working with URI templates. + *

+ * This interface provides methods for extracting variables from URI templates and + * matching them against actual URIs. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManager { + + /** + * Extract URI variable names from this URI template. + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + List getVariableNames(); + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param uri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + Map extractVariableValues(String uri); + + /** + * Indicate whether the given URI matches this template. + * @param uri the URI to match to + * @return {@code true} if it matches; {@code false} otherwise + */ + boolean matches(String uri); + + /** + * Check if the given URI is a URI template. + * @return Returns true if the URI contains variables in the format {variableName} + */ + public boolean isUriTemplate(String uri); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java new file mode 100644 index 000000000..9644f9a6c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -0,0 +1,22 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * Factory interface for creating instances of {@link McpUriTemplateManager}. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + McpUriTemplateManager create(String uriTemplate); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java new file mode 100644 index 000000000..6f041daa6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManager; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link McpUriTemplateManager} and its implementations. + * + * @author Christian Tzolov + */ +public class McpUriTemplateManagerTests { + + private McpUriTemplateManagerFactory uriTemplateFactory; + + @BeforeEach + void setUp() { + this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + } + + @Test + void shouldExtractVariableNamesFromTemplate() { + List variables = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .getVariableNames(); + assertEquals(2, variables.size()); + assertEquals("userId", variables.get(0)); + assertEquals("postId", variables.get(1)); + } + + @Test + void shouldReturnEmptyListWhenTemplateHasNoVariables() { + List variables = this.uriTemplateFactory.create("/api/users/all").getVariableNames(); + assertEquals(0, variables.size()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromNullTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create(null).getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromEmptyTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create("").getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenTemplateContainsDuplicateVariables() { + assertThrows(IllegalArgumentException.class, + () -> this.uriTemplateFactory.create("/api/users/{userId}/posts/{userId}").getVariableNames()); + } + + @Test + void shouldExtractVariableValuesFromRequestUri() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues("/api/users/123/posts/456"); + assertEquals(2, values.size()); + assertEquals("123", values.get("userId")); + assertEquals("456", values.get("postId")); + } + + @Test + void shouldReturnEmptyMapWhenTemplateHasNoVariables() { + Map values = this.uriTemplateFactory.create("/api/users/all") + .extractVariableValues("/api/users/all"); + assertEquals(0, values.size()); + } + + @Test + void shouldReturnEmptyMapWhenRequestUriIsNull() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues(null); + assertEquals(0, values.size()); + } + + @Test + void shouldMatchUriAgainstTemplatePattern() { + var uriTemplateManager = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}"); + + assertTrue(uriTemplateManager.matches("/api/users/123/posts/456")); + assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); + } + +} From e34babbe56b730514d35191118d5e66bc9c51b9a Mon Sep 17 00:00:00 2001 From: jito Date: Thu, 8 May 2025 18:31:17 +0900 Subject: [PATCH 044/205] Add missing isInitialized method to McpSyncClient (#181) The isInitialized method is present in McpAsyncClient and needs to be mirrored in McpSyncClient. Signed-off-by: jitokim --- .../io/modelcontextprotocol/client/McpSyncClient.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index c91638a7e..a8fb979e1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -97,6 +97,14 @@ public McpSchema.Implementation getServerInfo() { return this.delegate.getServerInfo(); } + /** + * Check if the client-server connection is initialized. + * @return true if the client-server connection is initialized + */ + public boolean isInitialized() { + return this.delegate.isInitialized(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities From eae3840e7d44932c60c131cb7a346b5367b788ff Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Fri, 25 Apr 2025 18:47:54 +0200 Subject: [PATCH 045/205] fix: Mockito inline mocking for Java 21+ (#207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this fix the execution of the maven surefire plugin with Java 21 logged warnings that mockito should be added as a java agent, because the self-attaching won't be supported in future java releases. In Java 24 the test just broke. This problem is solved by modifying the pom.xml of the parent and doing this changes: * Adding mockito as a java agent. * Removing the surefireArgLine from the properties. This can be added back when it's needed (for example when JaCoCo will be used). Furthermore, the pom.xml in the mcp-spring-* modules now have the byte-buddy dependency included, as the test would otherwise break when trying to mock McpSchema#CreateMessageRequest. Fixes #187 Co-authored-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 ++++++ mcp-spring/mcp-spring-webmvc/pom.xml | 6 ++++++ pom.xml | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 63c32a8a8..86f46bf95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -82,6 +82,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index b59be6a03..82fbbf3e6 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -77,6 +77,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + org.testcontainers junit-jupiter diff --git a/pom.xml b/pom.xml index 9be256ccf..638457406 100644 --- a/pom.xml +++ b/pom.xml @@ -57,6 +57,7 @@ 17 17 17 + 3.26.3 5.10.2 @@ -163,13 +164,23 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + properties + + + + org.apache.maven.plugins maven-surefire-plugin ${maven-surefire-plugin.version} - ${surefireArgLine} - + ${surefireArgLine} -javaagent:${org.mockito:mockito-core:jar} false false From 0069c977ef88b91162b08899bb8040a0ffcb8653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 9 May 2025 12:57:36 +0200 Subject: [PATCH 046/205] Remove temporary delegate impl from McpAsyncServer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServer.java | 1082 ++++++++--------- 1 file changed, 484 insertions(+), 598 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 3c112ad76..1efa13de3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -82,11 +82,33 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpAsyncServer delegate; + private final McpServerTransportProvider mcpTransportProvider; - McpAsyncServer() { - this.delegate = null; - } + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); /** * Create a new McpAsyncServer with the given transport provider and capabilities. @@ -98,8 +120,104 @@ public class McpAsyncServer { McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, - uriTemplateManagerFactory); + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); } /** @@ -107,7 +225,7 @@ public class McpAsyncServer { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.delegate.getServerCapabilities(); + return this.serverCapabilities; } /** @@ -115,7 +233,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.delegate.getServerInfo(); + return this.serverInfo; } /** @@ -123,26 +241,66 @@ public McpSchema.Implementation getServerInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.delegate.closeGracefully(); + return this.mcpTransportProvider.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.delegate.close(); + this.mcpTransportProvider.close(); + } + + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- // Tool Management // --------------------------------------- + /** * Add a new tool specification at runtime. * @param toolSpecification The tool specification to add * @return Mono that completes when clients have been notified of the change */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - return this.delegate.addTool(toolSpecification); + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); } /** @@ -151,7 +309,25 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica * @return Mono that completes when clients have been notified of the change */ public Mono removeTool(String toolName) { - return this.delegate.removeTool(toolName); + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); } /** @@ -159,19 +335,65 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - return this.delegate.notifyToolsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpServerSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; } // --------------------------------------- // Resource Management // --------------------------------------- + /** * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource handler to add * @return Mono that completes when clients have been notified of the change */ - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { - return this.delegate.addResource(resourceHandler); + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); } /** @@ -180,7 +402,24 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou * @return Mono that completes when clients have been notified of the change */ public Mono removeResource(String resourceUri) { - return this.delegate.removeResource(resourceUri); + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); } /** @@ -188,19 +427,97 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - return this.delegate.notifyResourcesListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private McpServerSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); + }; } // --------------------------------------- // Prompt Management // --------------------------------------- + /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add * @return Mono that completes when clients have been notified of the change */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - return this.delegate.addPrompt(promptSpecification); + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); } /** @@ -209,7 +526,27 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe * @return Mono that completes when clients have been notified of the change */ public Mono removePrompt(String promptName) { - return this.delegate.removePrompt(promptName); + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); } /** @@ -217,7 +554,39 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - return this.delegate.notifyPromptsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpServerSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; } // --------------------------------------- @@ -237,619 +606,136 @@ public Mono notifyPromptsListChanged() { */ @Deprecated public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - return this.delegate.loggingNotification(loggingMessageNotification); - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.delegate.setProtocolVersions(protocolVersions); - } - - private static class AsyncServerImpl extends McpAsyncServer { - - private final McpServerTransportProvider mcpTransportProvider; - - private final ObjectMapper objectMapper; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private final String instructions; - - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - // FIXME: this field is deprecated and should be remvoed together with the - // broadcasting loggingNotification. - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); - - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.instructions = features.instructions(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - this.completions.putAll(features.completions()); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - // Add completion API handlers if the completion capability is enabled - if (this.serverCapabilities.completions() != null) { - requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, - roots) -> Mono.fromRunnable(() -> logger.warn( - "Roots list changed notification, but no consumers provided. Roots list changed: {}", - roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports - // if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST - // respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, this.instructions)); - }); - } - - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - @Override - public Mono closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); } - @Override - public void close() { - this.mcpTransportProvider.close(); + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() - .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - @Override - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); - } - if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolSpecification.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); - } - - this.tools.add(toolSpecification); - logger.debug("Added tool handler: {}", toolSpecification.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - @Override - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + private McpServerSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - @Override - public Mono notifyToolsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } + return Mono.just(Map.of()); + }); + }; + } - // --------------------------------------- - // Resource Management - // --------------------------------------- + private McpServerSession.RequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); - @Override - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { - if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); } - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } + String type = request.ref().type(); - @Override - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } + String argumentName = request.argument().name(); - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - @Override - public Mono notifyResourcesListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; - } - - private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + if (!promptSpec.prompt() + .arguments() .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) + .filter(arg -> arg.name().equals(argumentName)) .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - - return specification.readHandler().apply(exchange, resourceRequest); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- + .isPresent()) { - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error(new McpError( - "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return Mono.empty(); - }); - } - - @Override - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); } - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - @Override - public Mono notifyPromptsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); - if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return specification.promptHandler().apply(exchange, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - @Override - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); } - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - - private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { - - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); - - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); - - return Mono.just(Map.of()); - }); - }; - } - - private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); - - if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); - } - - if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); - } - - String type = request.ref().type(); - - String argumentName = request.argument().name(); - - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); - if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); - } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { - - return Mono.error(new McpError("Argument not found: " + argumentName)); - } - } - - if (type.equals("ref/resource") - && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources - .get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); - } + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - } - - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - - if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); - } - - return specification.completionHandler().apply(exchange, request); - }; - } - - /** - * Parses the raw JSON-RPC request parameters into a - * {@link McpSchema.CompleteRequest} object. - *

- * This method manually extracts the `ref` and `argument` fields from the input - * map, determines the correct reference type (either prompt or resource), and - * constructs a fully-typed {@code CompleteRequest} instance. - * @param object the raw request parameters, expected to be a Map containing "ref" - * and "argument" entries. - * @return a {@link McpSchema.CompleteRequest} representing the structured - * completion request. - * @throws IllegalArgumentException if the "ref" type is not recognized. - */ - @SuppressWarnings("unchecked") - private McpSchema.CompleteRequest parseCompletionParams(Object object) { - Map params = (Map) object; - Map refMap = (Map) params.get("ref"); - Map argMap = (Map) params.get("argument"); - - String refType = (String) refMap.get("type"); - - McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); - default -> throw new IllegalArgumentException("Invalid ref type: " + refType); - }; - - String argName = (String) argMap.get("name"); - String argValue = (String) argMap.get("value"); - McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument( - argName, argValue); - - return new McpSchema.CompleteRequest(ref, argument); - } + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } - // --------------------------------------- - // Sampling - // --------------------------------------- + return specification.completionHandler().apply(exchange, request); + }; + } - @Override - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; } } From b2d3e0098e484e172719237b0933fa395cdfdf4b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 12 May 2025 15:04:05 +0200 Subject: [PATCH 047/205] Next development version Signed-off-by: Christian Tzolov --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 +++--- mcp-spring/mcp-spring-webmvc/pom.xml | 6 +++--- mcp-test/pom.xml | 4 ++-- mcp/pom.xml | 2 +- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 4f24f719f..7214dacda 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 86f46bf95..a8b92bd09 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 82fbbf3e6..48d1c3465 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f1484ae77..a6e5bdb08 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index 17693ab32..773432827 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp jar diff --git a/pom.xml b/pom.xml index 638457406..c2327ee8d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk From f34662555a0ab68d74ac118f1b0220441b2c81b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 15:38:02 +0200 Subject: [PATCH 048/205] Fix stdio tests - proper server-everything argument (#237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../modelcontextprotocol/client/StdioMcpAsyncClientTests.java | 4 ++-- .../modelcontextprotocol/client/StdioMcpSyncClientTests.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c39080138..8c0069d6d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -25,12 +25,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8e75c4a3d..706aa9b2e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -33,12 +33,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); From 2e13f9f9df8610e0d05cc76b1416fe195e249303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 22:46:54 +0200 Subject: [PATCH 049/205] Fix flaky WebFluxSse integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2ba047461..03fbc9962 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; @@ -651,9 +653,11 @@ void testInitialize(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) { + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications - List receivedNotifications = new ArrayList<>(); + List receivedNotifications = new CopyOnWriteArrayList<>(); var clientBuilder = clientBuilders.get(clientType); @@ -709,6 +713,7 @@ void testLoggingNotification(String clientType) { // Create client with logging notification handler var mcpClient = clientBuilder.loggingConsumer(notification -> { receivedNotifications.add(notification); + latch.countDown(); }).build()) { // Initialize client @@ -724,31 +729,28 @@ void testLoggingNotification(String clientType) { assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } mcpServer.close(); } From 1adfa8a047852c8f9e0188b4e63fe2020e0c66c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 14:05:39 +0200 Subject: [PATCH 050/205] Add Contributing Guidelines and Code of Conduct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CODE_OF_CONDUCT.md | 119 +++++++++++++++++++++++++++++++++++++++++++++ CONTRIBUTING.md | 91 ++++++++++++++++++++++++++++++++++ README.md | 7 +-- SECURITY.md | 21 ++++++++ 4 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 SECURITY.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..6009a645f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,119 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a +harassment-free experience for everyone, regardless of age, body size, visible or +invisible disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community +include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and + learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without + their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional + setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in response to +any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, +commits, code, wiki edits, issues, and other contributions that are not aligned to this +Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an +individual is officially representing the community in public spaces. Examples of +representing our community include using an official e-mail address, posting via an +official social media account, or acting as an appointed representative at an online or +offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to +the community leaders responsible for enforcement at mcp-coc@anthropic.com. All +complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter +of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the +consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity +around the nature of the violation and an explanation of why the behavior was +inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, for a specified period of time. This includes avoiding interactions in community +spaces as well as external channels like social media. Violating these terms may lead to +a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained +inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication +with the community for a specified period of time. No public or private interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead to a permanent +ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, +including sustained inappropriate behavior, harassment of an individual, or aggression +toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, +available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..a949dcc09 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# Contributing to Model Context Protocol Java SDK + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This document outlines how to contribute to this project. + +## Prerequisites + +The following software is required to work on the codebase: + +- `Java 17` or above +- `Docker` +- `npx` + +## Getting Started + +1. Fork the repository +2. Clone your fork: + +```bash +git clone https://github.com/YOUR-USERNAME/java-sdk.git +cd java-sdk +``` + +3. Build from source: + +```bash +./mvnw clean install -DskipTests # skip the tests +./mvnw test # run tests +``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + +```bash +git checkout -b feature/your-feature-name +``` + +2. Make your changes +3. Validate your changes: + +```bash +./mvnw clean test +``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +Please review our [Security Policy](SECURITY.md) for reporting security issues. \ No newline at end of file diff --git a/README.md b/README.md index 9fc17306e..0cd3f84a4 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,8 @@ To run the tests you have to pre-install `Docker` and `npx`. ## Contributing -Contributions are welcome! Please: - -1. Fork the repository -2. Create a feature branch -3. Submit a Pull Request +Contributions are welcome! +Please follow the [Contributing Guidelines](CONTRIBUTING.md). ## Team diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..74e9880fd --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,21 @@ +# Security Policy + +Thank you for helping us keep the SDKs and systems they interact with secure. + +## Reporting Security Issues + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model +Context Protocol project. + +The security of our systems and user data is Anthropic’s top priority. We appreciate the +work of security researchers acting in good faith in identifying and reporting potential +vulnerabilities. + +Our security program is managed on HackerOne and we ask that any validated vulnerability +in this functionality be reported through their +[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +## Vulnerability Disclosure Program + +Our Vulnerability Program Guidelines are defined on our +[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file From 07e7b8fd6bac47be4527f97451f8cdd95ed31a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 18:00:06 +0200 Subject: [PATCH 051/205] Add note about force pushes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a949dcc09..517f32555 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,6 +71,9 @@ git checkout -b feature/your-feature-name 4. Submit a pull request to the main repository. 5. Follow the pull request template. 6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. ## Code of Conduct From 8a5a591d39256ba3947003ec4477e1722363eb35 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 27 May 2025 15:26:44 -0700 Subject: [PATCH 052/205] feat: Add elicitation support to MCP protocol Implement elicitation capabilities allowing servers to request additional information from users through clients during interactions. This feature provides a standardized way for servers to gather necessary information dynamically while clients maintain control over user interactions and data sharing. - Add ElicitRequest and ElicitResult classes to McpSchema - Implement elicitation handlers in client classes - Add elicitation capabilities to server exchange classes - Add tests for elicitation functionality with various scenarios --- .../WebFluxSseIntegrationTests.java | 224 +++++++++++++++++- .../server/WebMvcSseIntegrationTests.java | 213 +++++++++++++++++ .../client/McpAsyncClient.java | 32 +++ .../client/McpClient.java | 40 +++- .../client/McpClientFeatures.java | 31 ++- .../server/McpAsyncServerExchange.java | 28 +++ .../server/McpSyncServerExchange.java | 18 ++ .../modelcontextprotocol/spec/McpSchema.java | 129 ++++++++-- .../client/AbstractMcpAsyncClientTests.java | 22 +- .../McpAsyncClientResponseHandlerTests.java | 150 ++++++++++++ ...rverTransportProviderIntegrationTests.java | 213 +++++++++++++++++ .../spec/McpSchemaTests.java | 34 +++ 12 files changed, 1106 insertions(+), 28 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 03fbc9962..2f85654e8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -28,11 +27,11 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -41,6 +40,7 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -331,6 +331,226 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt mcpServer.closeGracefully().block(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d68439..3f3f7be62 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -357,6 +357,219 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba3..a22ef6b51 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -23,6 +23,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -141,6 +143,15 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ @@ -189,6 +200,15 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -500,6 +520,18 @@ private RequestHandler samplingCreateMessageHandler() { }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc11685..280906cff 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; @@ -175,6 +177,8 @@ class SyncSpec { private Function samplingHandler; + private Function elicitationHandler; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -283,6 +287,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -364,7 +383,7 @@ public SyncSpec loggingConsumers(List> samplingHandler; + private Function> elicitationHandler; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -522,6 +543,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -606,7 +642,7 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f88..23d7c6a60 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -60,13 +60,15 @@ class McpClientFeatures { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { /** * Create an instance and validate the arguments. @@ -77,6 +79,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -84,14 +87,16 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -99,6 +104,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } /** @@ -138,9 +144,14 @@ public static Async fromSync(Sync syncSpec) { Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + samplingHandler, elicitationHandler); } } @@ -156,13 +167,15 @@ public static Async fromSync(Sync syncSpec) { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { /** * Create an instance and validate the arguments. @@ -174,20 +187,23 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -195,6 +211,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d0..cfb07d26c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -36,6 +36,9 @@ public class McpAsyncServerExchange { private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { }; + private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + }; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -93,6 +96,31 @@ public Mono createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54b..084412b96 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -64,6 +64,24 @@ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageReques return this.exchange.createMessage(createMessageRequest).block(); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a1584..9dae08266 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -94,6 +94,9 @@ private McpSchema() { // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // --------------------------- @@ -131,8 +134,8 @@ public static final class ErrorCodes { } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, + CompleteRequest, GetPromptRequest { } @@ -221,7 +224,7 @@ public record JSONRPCError( public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { + @JsonProperty("clientInfo") Implementation clientInfo) implements Request { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -245,6 +248,8 @@ public record InitializeResult( // @formatter:off * access to. * @param sampling Provides a standardized way for servers to request LLM sampling * (“completions” or “generations”) from language models via clients. + * @param elicitation Provides a standardized way for servers to request additional + * information from users through the client during interactions. * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -252,7 +257,8 @@ public record InitializeResult( // @formatter:off public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { /** * Roots define the boundaries of where servers can operate within the filesystem, @@ -264,7 +270,7 @@ public record ClientCapabilities( // @formatter:off * has changed since the last time the server checked. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -279,10 +285,22 @@ public record RootCapabilities( * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record Sampling() { } + /** + * Provides a standardized way for servers to request additional + * information from users through the client during interactions. + * This flow allows clients to maintain control over user + * interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation() { + } + public static Builder builder() { return new Builder(); } @@ -291,6 +309,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -307,8 +326,13 @@ public Builder sampling() { return this; } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); + return new ClientCapabilities(experimental, roots, sampling, elicitation); } } }// @formatter:on @@ -326,11 +350,11 @@ public record ServerCapabilities( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) public record CompletionCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record PromptCapabilities( @JsonProperty("listChanged") Boolean listChanged) { @@ -727,11 +751,11 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("inputSchema") JsonSchema inputSchema) { - + public Tool(String name, String description, String schema) { this(name, description, parseSchema(schema)); } - + } // @formatter:on private static JsonSchema parseSchema(String schema) { @@ -758,7 +782,7 @@ public record CallToolRequest(// @formatter:off @JsonProperty("arguments") Map arguments) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments)); } private static Map parseJsonArguments(String jsonArguments) { @@ -893,7 +917,7 @@ public record ModelPreferences(// @formatter:off @JsonProperty("costPriority") Double costPriority, @JsonProperty("speedPriority") Double speedPriority, @JsonProperty("intelligencePriority") Double intelligencePriority) { - + public static Builder builder() { return new Builder(); } @@ -963,7 +987,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata) implements Request { public enum ContextInclusionStrategy { @@ -971,7 +995,7 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } - + public static Builder builder() { return new Builder(); } @@ -1040,7 +1064,7 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason) { - + public enum StopReason { @JsonProperty("endTurn") END_TURN, @JsonProperty("stopSequence") STOP_SEQUENCE, @@ -1088,6 +1112,79 @@ public CreateMessageResult build() { } }// @formatter:on + // Elicitation + /** + * Used by the server to send an elicitation to the client. + * + * @param message The body of the elicitation message. + * @param requestedSchema The elicitation response schema that must be satisfied. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest(// @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema) implements Request { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String message; + private Map requestedSchema; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult(// @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content) { + + public enum Action { + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Action action; + private Map content; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content); + } + } + }// @formatter:on + // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af9..d1a2581ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -422,6 +424,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -433,7 +449,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b1529..e6cde8e3b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; @@ -349,4 +351,152 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + @SuppressWarnings("unchecked") + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); + + var properties = request.requestedSchema().get("properties"); + assertThat(properties).isNotNull(); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(McpError.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a4..dc9d1cfab 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -24,6 +24,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -339,6 +341,217 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + @Disabled + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ff78c1bfc..99015d8c4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -807,6 +807,40 @@ void testCreateMessageResult() throws Exception { {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); } + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + // Roots Tests @Test From 2f944349cc77009d020ebddc8b9967b8e1b14baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 10 Jun 2025 18:33:40 +0200 Subject: [PATCH 053/205] feat: WebClient Streamable HTTP support (#292) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An implementation of Streamable HTTP Client with WebFlux WebClient. Aside from implementing the specification, several improvements have been incorporated throughout the client-side of the architecture. The changes cover: - resilience tests using toxiproxy in testcontainers - integration tests using updated everything-server with streamableHttp support - improved logging - session invalidation handling (both transport session and JSON-RPC concept of session) - implicit initialization and burst protection (in case of concurrent `Mcp(Sync|Async)Client` use - more logging, e.g. stdio process lifecycle logs Related #72, #273, #253, #107, #105 Signed-off-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 + .../WebClientStreamableHttpTransport.java | 520 ++++++++++++++++++ .../transport/WebFluxSseClientTransport.java | 3 + ...eamableHttpAsyncClientResiliencyTests.java | 17 + ...bClientStreamableHttpAsyncClientTests.java | 42 ++ ...ebClientStreamableHttpSyncClientTests.java | 41 ++ .../client/WebFluxSseMcpAsyncClientTests.java | 3 +- .../client/WebFluxSseMcpSyncClientTests.java | 3 +- .../WebFluxSseClientTransportTests.java | 3 +- .../src/test/resources/logback.xml | 8 +- mcp-test/pom.xml | 5 + ...AbstractMcpAsyncClientResiliencyTests.java | 222 ++++++++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 60 +- .../client/McpAsyncClient.java | 291 ++++++---- .../transport/StdioClientTransport.java | 5 + .../spec/DefaultMcpTransportSession.java | 79 +++ .../spec/DefaultMcpTransportStream.java | 74 +++ .../spec/McpClientSession.java | 53 +- .../spec/McpClientTransport.java | 22 +- .../modelcontextprotocol/spec/McpSchema.java | 6 + .../spec/McpTransportSession.java | 60 ++ .../McpTransportSessionNotFoundException.java | 29 + .../spec/McpTransportStream.java | 45 ++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 58 +- .../client/HttpSseMcpAsyncClientTests.java | 3 +- .../client/HttpSseMcpSyncClientTests.java | 3 +- .../client/StdioMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransportTests.java | 3 +- pom.xml | 3 +- 31 files changed, 1501 insertions(+), 242 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index a8b92bd09..26452fe95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -99,6 +99,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 000000000..e7b7c8ee9 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,520 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

+ * + * @author Dariusz Jędrzejczyk + * @see Streamable + * HTTP transport specification + */ +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + + private final ObjectMapper objectMapper; + + private final WebClient webClient; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + /** + * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} + * instances. + * @param webClientBuilder the {@link WebClient.Builder} to use + * @return a builder which will create an instance of + * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Supplier> onClose = () -> { + DefaultMcpTransportSession transportSession = this.activeSession.get(); + return transportSession.sessionId().isEmpty() ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); + }; + return new DefaultMcpTransportSession(onClose); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + logger.debug("Sending message {}", message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.post() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + }) + .bodyValue(message) + .exchangeToFlux(response -> { + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); + // Existing SDKs consume notifications with no response body nor + // content type + if (contentType.isEmpty()) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + else { + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return responseFlux(response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } + } + } + else { + if (isNotFound(response)) { + return mcpSessionNotFoundError(sessionRepresentation); + } + return extractError(response, sessionRepresentation); + } + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorResume(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + sink.error(t); + return Flux.empty(); + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.empty(); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + private Flux responseFlux(ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // We don't support batching ATM and probably won't since the next version + // considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpError("Error parsing JSON-RPC message: " + event.data()); + } + } + else { + throw new McpError("Received unrecognized SSE event type: " + event.event()); + } + } + + /** + * Builder for {@link WebClientStreamableHttpTransport}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. + * @param webClientBuilder instance to use + * @return the builder instance + */ + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link WebClientStreamableHttpTransport} + */ + public WebClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, + openConnectionOnStartup); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295b..128cda4c3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..80fc671e2 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..4c8032659 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..a8cad4898 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c14493..f0533cb49 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6dd..9b0959a35 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da9..42b91d14e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374a..abc831d13 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index a6e5bdb08..9998569dc 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -68,6 +68,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..85d6a88e4 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,222 @@ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + private static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + private static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + private static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8eac..049bea008 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -110,14 +110,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -153,7 +155,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -168,7 +170,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -233,7 +235,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -258,7 +260,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -279,7 +281,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -354,7 +356,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f80..3785fd645 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -112,33 +112,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -175,8 +160,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -200,7 +185,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -214,7 +199,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -257,7 +242,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -355,7 +346,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -413,8 +405,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index a22ef6b51..8f0433eb1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -9,9 +9,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpClientSession; @@ -32,7 +32,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -77,29 +77,37 @@ * @see McpClient * @see McpSchema * @see McpClientSession + * @see McpClientTransport */ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); + public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; - private AtomicBoolean initialized = new AtomicBoolean(false); + private final AtomicReference initializationRef = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. */ private final Duration initializationTimeout; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final McpClientSession mcpSession; - /** * Client capabilities. */ @@ -110,21 +118,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server instructions. - */ - private String serverInstructions; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -155,13 +148,19 @@ public class McpAsyncClient { /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; /** * Supported protocol versions. */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Supplier sessionSupplier; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. @@ -254,8 +253,29 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.transport.setExceptionHandler(this::handleException); + this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers); + + } + + private void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + Initialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering the + // implicit initialization step. + withSession("re-initializing", result -> Mono.empty()).subscribe(); + } + } + private McpSchema.InitializeResult currentInitializationResult() { + Initialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; } /** @@ -263,7 +283,8 @@ public class McpAsyncClient { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.capabilities() : null; } /** @@ -272,7 +293,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - return this.serverInstructions; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -280,7 +302,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -288,7 +311,8 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + Initialization current = this.initializationRef.get(); + return current != null && (current.result.get() != null); } /** @@ -311,7 +335,11 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + Initialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + this.transport.close(); } /** @@ -319,14 +347,21 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return Mono.defer(() -> { + Initialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); + }); } // -------------------------- // Initialization // -------------------------- /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *